From de6db3fed6f0f69a76e3023fb572bfeef68f56fd Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 9 Feb 2026 08:23:16 +0100 Subject: [PATCH] use float_type for dequantize4 functions --- .../ggml-vulkan/vulkan-shaders/flash_attn.comp | 8 ++++---- .../vulkan-shaders/flash_attn_base.glsl | 18 +++++++++--------- .../vulkan-shaders/flash_attn_cm1.comp | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 4d95ef58f9..0a425ce75f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -193,7 +193,7 @@ void main() { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif @@ -224,7 +224,7 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif @@ -299,7 +299,7 @@ void main() { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - V_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); #endif @@ -331,7 +331,7 @@ void main() { uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - Vf = FLOAT_TYPEV4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 5147b0236c..88b6d65bee 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16 #define BLOCK_SIZE 4 #define BLOCK_BYTE_SIZE 16 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { // iqs is currently always zero in the flash attention shaders if (binding_idx == BINDING_IDX_K) { - return k_packed.k_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); } else { - return v_packed.v_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); } } #endif @@ -110,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -118,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -126,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } else { const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 3768b23053..ebb55c9504 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -234,7 +234,7 @@ void main() { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif @@ -269,7 +269,7 @@ void main() { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); #endif @@ -400,7 +400,7 @@ void main() { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - V_Tf = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); #endif @@ -444,7 +444,7 @@ void main() { if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif