From 82764c341a182218f9c391774e2013400e5f4b29 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 31 Mar 2026 22:38:24 -0700 Subject: [PATCH] ggml webgpu: quantized buffers to u32 + wider browser/device support (#21046) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 10 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 17 +- .../wgsl-shaders/common_decls.tmpl | 24 +++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 81 ++++++-- .../wgsl-shaders/mul_mat_decls.tmpl | 194 +++++++----------- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 103 ++++------ 6 files changed, 206 insertions(+), 223 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 97863f4041..a194ce84e2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib { defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); - - // For fast path we always dequantize from f16 inside the shader - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); break; } } @@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("INIT_SRC0_SHMEM_" + type_upper); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - // Use f16 inside the shader for quantized types - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); variant += std::string("_") + src0_name; break; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index fa3c492a7a..1aa15b0507 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #define WEBGPU_NUM_PARAM_BUFS 96u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) @@ -171,6 +171,7 @@ struct webgpu_buf_pool { // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; + lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks wgpu::Buffer dev_buf; ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); @@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0); + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); @@ -728,7 +729,6 @@ static void 
ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); std::vector commands = { command }; std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, // memset the remaining bytes ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); - } else { - // wait for WriteBuffer to complete - buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); } WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 9a5b18ebc0..feb0bca3f8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { } #endif +#ifdef U32_DEQUANT_HELPERS +fn load_src0_u16_at(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_src0_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = src0[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = src0[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_src0_f16_at(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_src0_u16_at(byte_offset)); + return f16(packed[0]); +} +#endif + #ifdef Q4_0_T struct q4_0 { d: f16, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index b682216146..8b76cecba9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 +#elif defined(KV_Q4_0) || defined(KV_Q8_0) +#define KV_TYPE u32 #else #define KV_TYPE f16 #endif @@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix; #define NQ 16 // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights #define F16_PER_BLOCK 9 +#define BLOCK_SIZE_BYTES 18u #define WEIGHTS_PER_F16 4 #elif defined(KV_Q8_0) #define NQ 8 // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights #define F16_PER_BLOCK 17 +#define BLOCK_SIZE_BYTES 34u #define WEIGHTS_PER_F16 2 #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) @@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#if defined(KV_Q4_0) || defined(KV_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = 
K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} +#endif + struct Params { offset_q: u32, offset_k: u32, @@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; @@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + 
j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index de60ebbcf2..eb228537ba 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_0 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_1 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
// tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread @@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. 
override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block @@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q2_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 42u; +const BLOCK_SIZE_BYTES = 84u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { // Use standard thread layout instead of lane/row_group @@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 40u]; - let dmin = src0[scale_idx + 41u]; + let d = load_src0_f16_at(block_byte_base + 80u); + let dmin = load_src0_f16_at(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_0 = src0[scale_idx + 2u * (is / 4u)]; - let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; - let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q3_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 55u; +const BLOCK_SIZE_BYTES = 110u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * 
BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 54u]; + let d = load_src0_f16_at(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - let scale_0 = src0[scale_idx + 48u + (2u*i)]; - let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - let hmask_0 = src0[scale_idx + (2u*i)]; - let hmask_1 = src0[scale_idx + (2u*i) + 1u]; - hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - let qs_0 = src0[scale_idx + 16u + (2u*i)]; - let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; - qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 72u; +const BLOCK_SIZE_BYTES = 144u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 88u; +const BLOCK_SIZE_BYTES = 176u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let 
scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; - let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; - let qh_packed = bitcast(vec2(qh_0, qh_1)); + let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; +const BLOCK_SIZE_BYTES = 210u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let half = k_in_block / 128u; let pos_in_half = k_in_block % 128u; @@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13_word = ql13_flat / 4u; - let ql13 = bitcast(vec2( - src0[scale_idx + 2u * ql13_word], - src0[scale_idx + 2u * ql13_word + 1u] - )); - let ql13_b = get_byte(ql13, ql13_flat % 4u); + let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24_word = ql24_flat / 4u; - let ql24 = bitcast(vec2( - src0[scale_idx + 2u * ql24_word], - src0[scale_idx + 2u * ql24_word + 1u] - )); - let ql24_b = get_byte(ql24, ql24_flat % 4u); + let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh_word = qh_flat / 4u; - let qh = bitcast(vec2( - src0[scale_idx + 64u + 2u * qh_word], - src0[scale_idx + 64u + 2u * qh_word + 1u] - )); - let qh_b = get_byte(qh, qh_flat % 4u); + let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); @@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - 
let sc_word = sc_idx / 4u; - let sc = bitcast(vec2( - src0[scale_idx + 96u + 2u * sc_word], - src0[scale_idx + 96u + 2u * sc_word + 1u] - )); - let sc_val = get_byte_i32(sc, sc_idx % 4u); + let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc_val = get_byte_i32(sc, 0u); - let d = src0[scale_idx + 104u]; + let d = load_src0_f16_at(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 94f4bae11f..6525f23bdf 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 18u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 20u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 10u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = f32(src0[scale_idx + 1u]); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = 
load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 22u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 11u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 24u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 12u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 34u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 17u; const 
WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 36u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 18u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; - -fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 { - let aligned = byte_offset & ~3u; - let idx = bbase + aligned / 2u; - return bitcast(vec2(src0[idx], src0[idx + 1u])); -} +const BLOCK_SIZE_BYTES = 210u; fn byte_of(v: u32, b: u32) -> u32 { return (v >> (b * 8u)) & 0xFFu; @@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; + let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d_raw = load_u32_at(bbase, 208u); - let d = f32(bitcast>(d_raw)[0]); + let d = f32(load_src0_f16_at(bbase + 208u)); - let ql1_u32 = load_u32_at(bbase, q_offset_l); - 
let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); - let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); - let sc_u32_0 = load_u32_at(bbase, sc_base_byte); - let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u); + let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); + let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u); + let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); + let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); + let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u);
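
Illustrative appendix (not part of the patch): a minimal, self-contained WGSL sketch of the byte-addressed u32 load pattern the shaders above switch to. The helper bodies and the Q4_0 layout (an f16 scale in bytes 0..1 followed by 16 bytes of packed 4-bit weights, 18 bytes per block) mirror the patch; the bind group layout, the `dst` buffer, and the `demo_dequant_q4_0` entry point are assumptions added only so the sketch compiles on its own.

    enable f16;

    // Assumed bindings for the sketch only: the patch's templates bind src0 differently.
    @group(0) @binding(0) var<storage, read> src0: array<u32>;        // quantized blocks viewed as u32 words
    @group(0) @binding(1) var<storage, read_write> dst: array<f32>;   // dequantized output (hypothetical)

    // Read the 16-bit halfword starting at byte_offset (2-byte aligned).
    fn load_src0_u16_at(byte_offset: u32) -> u32 {
        let word = src0[byte_offset / 4u];
        let shift = (byte_offset & 2u) * 8u;   // 0 for the low half, 16 for the high half
        return (word >> shift) & 0xFFFFu;
    }

    // Read 32 bits starting at byte_offset, stitching two words when unaligned.
    fn load_src0_u32_at(byte_offset: u32) -> u32 {
        let word_idx = byte_offset / 4u;
        let shift = (byte_offset & 3u) * 8u;
        let lo = src0[word_idx];
        if (shift == 0u) {
            return lo;
        }
        let hi = src0[word_idx + 1u];
        return (lo >> shift) | (hi << (32u - shift));
    }

    // Recover an f16 from raw bits via unpack2x16float, avoiding an f16 bitcast.
    fn load_src0_f16_at(byte_offset: u32) -> f16 {
        return f16(unpack2x16float(load_src0_u16_at(byte_offset))[0]);
    }

    // Dequantize one Q4_0 block (32 weights, 18 bytes): scale at byte 0,
    // packed 4-bit weights from byte 2. One invocation per block, for illustration.
    @compute @workgroup_size(1)
    fn demo_dequant_q4_0(@builtin(global_invocation_id) gid: vec3<u32>) {
        let block_byte_base = gid.x * 18u;           // BLOCK_SIZE_BYTES for Q4_0
        let d = load_src0_f16_at(block_byte_base);   // f16 scale
        for (var w = 0u; w < 4u; w++) {              // 4 u32 words = 16 bytes of nibbles
            let q_packed = load_src0_u32_at(block_byte_base + 2u + 4u * w);
            for (var k = 0u; k < 4u; k++) {
                let q_byte = (q_packed >> (k * 8u)) & 0xFFu;
                let idx = 4u * w + k;
                // low nibble -> weights [0,16), high nibble -> weights [16,32)
                dst[gid.x * 32u + idx]       = f32(f16(q_byte & 0xFu) - 8.0) * f32(d);
                dst[gid.x * 32u + idx + 16u] = f32(f16((q_byte >> 4u) & 0xFu) - 8.0) * f32(d);
            }
        }
    }

The byte offsets line up one-for-one with the old f16 indexing in the removed lines: the former `scale_idx + 1u + block_offset + j` (counted in f16 elements) becomes `block_byte_base + 2u + 2u * (block_offset + j)` in bytes, since each f16 occupies two bytes, and the scale at f16 index 0 is now read through unpack2x16float. Viewing src0 as array<u32> this way appears to be what removes the dependence on f16 storage buffers and f16 bitcasts that the commit message's "wider browser/device support" refers to.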