From 47817aad47728ce0c29a6e81f86ee9f01eba491a Mon Sep 17 00:00:00 2001 From: Italo Nicola Date: Wed, 18 Feb 2026 12:49:01 -0300 Subject: [PATCH] vulkan: add TQ1_0 support for MUL_MAT Adds support for TQ2_0 type in the vulkan backend for mul_mat and mul_mat_vec. Sponsored-by: Tether Inc. Signed-off-by: Italo Nicola --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 ++++++- .../vulkan-shaders/dequant_funcs.glsl | 19 +++++++- .../vulkan-shaders/dequant_funcs_cm2.glsl | 22 ++++++++++ .../vulkan-shaders/dequant_tq.comp | 4 +- .../vulkan-shaders/mul_mat_vec_tq.comp | 4 +- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 26 +++++++++++ .../vulkan-shaders/mul_mm_funcs.glsl | 20 ++++++++- .../ggml-vulkan/vulkan-shaders/tq_utils.comp | 44 +++++++++++++++++++ .../src/ggml-vulkan/vulkan-shaders/types.glsl | 17 +++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 15 ++++--- 10 files changed, 175 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4ff45e471b..7b7831e190 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3448,6 +3448,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ2_0], matmul_tq2_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ1_0], matmul_tq1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) @@ -3538,6 +3539,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3562,6 +3564,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3660,6 +3663,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3824,6 +3828,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3992,6 +3997,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f32_f32", arr_dmmv_tq2_0_f32_f32_len[reduc], arr_dmmv_tq2_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f32_f32", arr_dmmv_tq1_0_f32_f32_len[reduc], arr_dmmv_tq1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); @@ -4017,6 +4023,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f16_f32", arr_dmmv_tq2_0_f16_f32_len[reduc], arr_dmmv_tq2_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f16_f32", arr_dmmv_tq1_0_f16_f32_len[reduc], arr_dmmv_tq1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); @@ -4039,6 +4046,7 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_q8_1_f32", arr_dmmv_tq2_0_q8_1_f32_len[reduc], arr_dmmv_tq2_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_q8_1_f32", arr_dmmv_tq1_0_q8_1_f32_len[reduc], arr_dmmv_tq1_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); @@ -4123,6 +4131,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ1_0], "dequant_tq1_0", dequant_tq1_0_len, dequant_tq1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); @@ -5905,6 +5914,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -5978,6 +5988,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ1_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -6016,6 +6027,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * if (b_type == GGML_TYPE_Q8_1) { switch (a_type) { case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6045,6 +6057,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_TQ2_0: + case GGML_TYPE_TQ1_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -14927,7 +14940,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm // If there's not enough shared memory for row_ids and the result tile, fallback to CPU return false; } - if (src0_type == GGML_TYPE_TQ2_0) { + if (src0_type == GGML_TYPE_TQ1_0 || src0_type == GGML_TYPE_TQ2_0) { return false; } } @@ -14954,6 +14967,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: break; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 5612e35cf6..35eaf23fa6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -437,6 +437,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_TQ1_0) +#include "tq_utils.comp" + +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(tq1_dequantize(ib + a_offset, iqs), tq1_dequantize(ib + a_offset, iqs + 1)); +} + +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4( + tq1_dequantize(ib + a_offset, iqs + 0), + tq1_dequantize(ib + a_offset, iqs + 1), + tq1_dequantize(ib + a_offset, iqs + 2), + tq1_dequantize(ib + a_offset, iqs + 3) + ); +} +#endif + #if defined(DATA_A_MXFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -464,7 +481,7 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) vec2 get_dm(uint ib, uint a_offset) { return vec2(float(data_a[a_offset + ib].d), 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 99f41caf2c..dc99b26e0d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -686,6 +686,26 @@ float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords } #endif +#if defined(DATA_A_TQ1_0) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ1_0 { + block_tq1_0 block; +}; + +#define TQ1_CM2 1 +#include "tq_utils.comp" +#undef TQ1_CM2 + +float16_t dequantFuncTQ1_0(const in decodeBufTQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const int val = tq1_dequantize(bl, idx); + + return d * float16_t(val); +} +#endif + #if defined(DATA_A_MXFP4) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { block_mxfp4 block; @@ -749,6 +769,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_TQ2_0) #define dequantFuncA dequantFuncTQ2_0 +#elif defined(DATA_A_TQ1_0) +#define dequantFuncA dequantFuncTQ1_0 #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 #elif defined(DATA_A_F32) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq.comp index 460ef8c821..ca624d3ffb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq.comp @@ -27,7 +27,9 @@ void main() { for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) { const uint e = (i + j) % QUANT_K; -#if defined(DATA_A_TQ2_0) +#if defined(DATA_A_TQ1_0) + data_b[i + j] = D_TYPE(tq1_dequantize(ib, e) * d); +#elif defined(DATA_A_TQ2_0) data_b[i + j] = D_TYPE(tq2_dequantize(ib, e) * d); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq.comp index da870a1a1c..1461d24bcb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq.comp @@ -30,7 +30,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); [[unroll]] for (uint e = tid % 8; e < 256; e += 8) { -#if defined(DATA_A_TQ2_0) +#if defined(DATA_A_TQ1_0) + const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq1_dequantize(ib0 + i, e)) * d; +#elif defined(DATA_A_TQ2_0) const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq2_dequantize(ib0 + i, e)) * d; #endif const uint b_idx = i * QUANT_K + e; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index b2ca69b9ea..d4465c3c42 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -138,6 +138,32 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int } #endif +#if defined(DATA_A_TQ1_0) +#include "tq_utils.comp" + +i32vec2 repack(uint ib, uint iqs) { + const int t00 = tq1_dequantize(ib, iqs + 0); + const int t01 = tq1_dequantize(ib, iqs + 1); + const int t02 = tq1_dequantize(ib, iqs + 2); + const int t03 = tq1_dequantize(ib, iqs + 3); + + const int v0 = (t00 & 0xFF) | ((t01 & 0xFF) << 8) | ((t02 & 0xFF) << 16) | ((t03 & 0xFF) << 24); + + const int t10 = tq1_dequantize(ib, iqs + 32 + 0); + const int t11 = tq1_dequantize(ib, iqs + 32 + 1); + const int t12 = tq1_dequantize(ib, iqs + 32 + 2); + const int t13 = tq1_dequantize(ib, iqs + 32 + 3); + + const int v1 = (t10 & 0xFF) | ((t11 & 0xFF) << 8) | ((t12 & 0xFF) << 16) | ((t13 & 0xFF) << 24); + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * float(q_sum) * dsb.x); +} +#endif + #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) i32vec2 repack(uint ib, uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 7d6ee75efc..647f3665e5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -1,4 +1,4 @@ -#if defined(DATA_A_TQ2_0) +#if defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0) #include "tq_utils.comp" #endif @@ -152,6 +152,24 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 v = vec2(v0, v1); buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_TQ1_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; + const uint iqs = idx % 128; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib].d); + + const uint e0 = 2 * iqs; + const uint e1 = e0 + 1; + + FLOAT_TYPE v0 = FLOAT_TYPE(tq1_dequantize(ib, e0)) * d; + FLOAT_TYPE v1 = FLOAT_TYPE(tq1_dequantize(ib, e1)) * d; + + const vec2 v = vec2(v0, v1); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp index 5fabd71221..c4abde19a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tq_utils.comp @@ -21,4 +21,48 @@ int tq2_dequantize(uint ib, uint iqs) { } #endif +#if defined(DATA_A_TQ1_0) +#if defined(TQ1_CM2) +int tq1_dequantize(const in decodeBufTQ1_0 bl, uint iqs) { +#else +int tq1_dequantize(uint ib, uint iqs) { +#endif + const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243); + + if (iqs < 160) { + const uint trit = iqs / 32; + const uint byte = iqs % 32; +#if defined(TQ1_CM2) + const uint q = uint(bl.block.qs[byte]); +#else + const uint q = uint(data_a[ib].qs[byte]); +#endif + const uint val = (((q * pow3[trit]) & 255) * 3) / 256; + return int(val) - 1; + } else if (iqs < 240) { + const uint relative_idx = iqs - 160; + const uint trit = relative_idx / 16; + const uint byte = relative_idx % 16; +#if defined(TQ1_CM2) + const uint q = uint(bl.block.qs[32 + byte]); +#else + const uint q = uint(data_a[ib].qs[32 + byte]); +#endif + const uint val = (((q * pow3[trit]) & 255) * 3) / 256; + return int(val) - 1; + } else { + const uint relative_idx = iqs - 240; + const uint trit = relative_idx / 4; + const uint byte = relative_idx % 4; +#if defined(TQ1_CM2) + const uint q = uint(bl.block.qh[byte]); +#else + const uint q = uint(data_a[ib].qh[byte]); +#endif + const uint val = (((q * pow3[trit]) & 255) * 3) / 256; + return int(val) - 1; + } +} +#endif + #endif \ No newline at end of file diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 24ecb20fc9..ab721f3397 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1680,6 +1680,23 @@ struct block_iq4_nl_packed16 #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif +// TQ1_0 +#define QUANT_K_S_TQ1_0 240 +#define QUANT_K_H_TQ1_0 16 +#define QUANT_K_TQ1_0 (QUANT_K_S_TQ1_0 + QUANT_K_H_TQ1_0) + +struct block_tq1_0 { + uint8_t qs[QUANT_K_S_TQ1_0 / 5]; + uint8_t qh[QUANT_K_H_TQ1_0 / 4]; + float16_t d; +}; + +#if defined(DATA_A_TQ1_0) +#define QUANT_K QUANT_K_TQ1_0 +#define QUANT_AUXF 1 +#define A_TYPE block_tq1_0 +#endif + // TQ2_0 #define QUANT_K_TQ2_0 256 #define QUANT_R_TQ2_0 4 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 6848cf8488..24a484e7fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -51,6 +51,7 @@ const std::vector type_names = { "q5_1", "q8_0", "tq2_0", + "tq1_0", "q2_k", "q3_k", "q4_k", @@ -684,7 +685,7 @@ void process_shaders() { std::string shader; if (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) { shader = "mul_mat_vec_" + tname + ".comp"; - } else if (tname == "tq2_0") { + } else if (tname == "tq2_0" || tname == "tq1_0") { shader = "mul_mat_vec_tq.comp"; } else { shader = "mul_mat_vec.comp"; @@ -699,7 +700,7 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - if (tname != "tq2_0") { + if (tname != "tq2_0" && tname != "tq1_0") { string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); @@ -715,7 +716,7 @@ void process_shaders() { string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - } else if (tname == "tq2_0") { + } else if (tname == "tq2_0" || tname == "tq1_0") { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); @@ -723,13 +724,13 @@ void process_shaders() { #endif // Dequant shaders - if (tname == "tq2_0") { + if (tname == "tq2_0" || tname == "tq1_0") { string_to_spv("dequant_" + tname, "dequant_tq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } else if (tname != "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } - if (tname == "tq2_0") continue; + if (tname == "tq2_0" || tname == "tq1_0") continue; shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; @@ -1157,7 +1158,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0" && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0" && tname != "tq1_0" && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; @@ -1167,7 +1168,7 @@ void write_output_files() { src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; } - if (btype == "f16" || tname == "tq2_0") { + if (btype == "f16" || tname == "tq2_0" || tname == "tq1_0") { continue; } hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";