vulkan: add TQ1_0 support for MUL_MAT

Adds support for TQ2_0 type in the vulkan backend for mul_mat and
mul_mat_vec.

Sponsored-by: Tether Inc.
Signed-off-by: Italo Nicola <italo.nicola@collabora.com>
This commit is contained in:
Italo Nicola 2026-02-18 12:49:01 -03:00
parent 782d22ca35
commit 47817aad47
10 changed files with 175 additions and 12 deletions

View File

@ -3448,6 +3448,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ2_0], matmul_tq2_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ1_0], matmul_tq1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
@ -3538,6 +3539,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@ -3562,6 +3564,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@ -3660,6 +3663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@ -3824,6 +3828,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@ -3992,6 +3997,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f32_f32", arr_dmmv_tq2_0_f32_f32_len[reduc], arr_dmmv_tq2_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f32_f32", arr_dmmv_tq1_0_f32_f32_len[reduc], arr_dmmv_tq1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
@ -4017,6 +4023,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f16_f32", arr_dmmv_tq2_0_f16_f32_len[reduc], arr_dmmv_tq2_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f16_f32", arr_dmmv_tq1_0_f16_f32_len[reduc], arr_dmmv_tq1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
@ -4039,6 +4046,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_q8_1_f32", arr_dmmv_tq2_0_q8_1_f32_len[reduc], arr_dmmv_tq2_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_q8_1_f32", arr_dmmv_tq1_0_q8_1_f32_len[reduc], arr_dmmv_tq1_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
@ -4123,6 +4131,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ1_0], "dequant_tq1_0", dequant_tq1_0_len, dequant_tq1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
@ -5905,6 +5914,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
@ -5978,6 +5988,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -6016,6 +6027,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
if (b_type == GGML_TYPE_Q8_1) {
switch (a_type) {
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -6045,6 +6057,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -14927,7 +14940,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
return false;
}
if (src0_type == GGML_TYPE_TQ2_0) {
if (src0_type == GGML_TYPE_TQ1_0 || src0_type == GGML_TYPE_TQ2_0) {
return false;
}
}
@ -14954,6 +14967,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_MXFP4:
break;

View File

@ -437,6 +437,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_TQ1_0)
#include "tq_utils.comp"
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(tq1_dequantize(ib + a_offset, iqs), tq1_dequantize(ib + a_offset, iqs + 1));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(
tq1_dequantize(ib + a_offset, iqs + 0),
tq1_dequantize(ib + a_offset, iqs + 1),
tq1_dequantize(ib + a_offset, iqs + 2),
tq1_dequantize(ib + a_offset, iqs + 3)
);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@ -464,7 +481,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}

View File

@ -686,6 +686,26 @@ float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords
}
#endif
#if defined(DATA_A_TQ1_0)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ1_0 {
block_tq1_0 block;
};
#define TQ1_CM2 1
#include "tq_utils.comp"
#undef TQ1_CM2
float16_t dequantFuncTQ1_0(const in decodeBufTQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const int val = tq1_dequantize(bl, idx);
return d * float16_t(val);
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
@ -749,6 +769,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_TQ2_0)
#define dequantFuncA dequantFuncTQ2_0
#elif defined(DATA_A_TQ1_0)
#define dequantFuncA dequantFuncTQ1_0
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)

View File

@ -27,7 +27,9 @@ void main() {
for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) {
const uint e = (i + j) % QUANT_K;
#if defined(DATA_A_TQ2_0)
#if defined(DATA_A_TQ1_0)
data_b[i + j] = D_TYPE(tq1_dequantize(ib, e) * d);
#elif defined(DATA_A_TQ2_0)
data_b[i + j] = D_TYPE(tq2_dequantize(ib, e) * d);
#endif
}

View File

@ -30,7 +30,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint e = tid % 8; e < 256; e += 8) {
#if defined(DATA_A_TQ2_0)
#if defined(DATA_A_TQ1_0)
const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq1_dequantize(ib0 + i, e)) * d;
#elif defined(DATA_A_TQ2_0)
const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq2_dequantize(ib0 + i, e)) * d;
#endif
const uint b_idx = i * QUANT_K + e;

View File

@ -138,6 +138,32 @@ ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int
}
#endif
#if defined(DATA_A_TQ1_0)
#include "tq_utils.comp"
i32vec2 repack(uint ib, uint iqs) {
const int t00 = tq1_dequantize(ib, iqs + 0);
const int t01 = tq1_dequantize(ib, iqs + 1);
const int t02 = tq1_dequantize(ib, iqs + 2);
const int t03 = tq1_dequantize(ib, iqs + 3);
const int v0 = (t00 & 0xFF) | ((t01 & 0xFF) << 8) | ((t02 & 0xFF) << 16) | ((t03 & 0xFF) << 24);
const int t10 = tq1_dequantize(ib, iqs + 32 + 0);
const int t11 = tq1_dequantize(ib, iqs + 32 + 1);
const int t12 = tq1_dequantize(ib, iqs + 32 + 2);
const int t13 = tq1_dequantize(ib, iqs + 32 + 3);
const int v1 = (t10 & 0xFF) | ((t11 & 0xFF) << 8) | ((t12 & 0xFF) << 16) | ((t13 & 0xFF) << 24);
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * float(q_sum) * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {

View File

@ -1,4 +1,4 @@
#if defined(DATA_A_TQ2_0)
#if defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0)
#include "tq_utils.comp"
#endif
@ -152,6 +152,24 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const vec2 v = vec2(v0, v1);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_TQ1_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128;
const uint iqs = idx % 128;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib].d);
const uint e0 = 2 * iqs;
const uint e1 = e0 + 1;
FLOAT_TYPE v0 = FLOAT_TYPE(tq1_dequantize(ib, e0)) * d;
FLOAT_TYPE v1 = FLOAT_TYPE(tq1_dequantize(ib, e1)) * d;
const vec2 v = vec2(v0, v1);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

View File

@ -21,4 +21,48 @@ int tq2_dequantize(uint ib, uint iqs) {
}
#endif
#if defined(DATA_A_TQ1_0)
#if defined(TQ1_CM2)
int tq1_dequantize(const in decodeBufTQ1_0 bl, uint iqs) {
#else
int tq1_dequantize(uint ib, uint iqs) {
#endif
const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);
if (iqs < 160) {
const uint trit = iqs / 32;
const uint byte = iqs % 32;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qs[byte]);
#else
const uint q = uint(data_a[ib].qs[byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else if (iqs < 240) {
const uint relative_idx = iqs - 160;
const uint trit = relative_idx / 16;
const uint byte = relative_idx % 16;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qs[32 + byte]);
#else
const uint q = uint(data_a[ib].qs[32 + byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else {
const uint relative_idx = iqs - 240;
const uint trit = relative_idx / 4;
const uint byte = relative_idx % 4;
#if defined(TQ1_CM2)
const uint q = uint(bl.block.qh[byte]);
#else
const uint q = uint(data_a[ib].qh[byte]);
#endif
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
}
}
#endif
#endif

View File

@ -1680,6 +1680,23 @@ struct block_iq4_nl_packed16
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
// TQ1_0
#define QUANT_K_S_TQ1_0 240
#define QUANT_K_H_TQ1_0 16
#define QUANT_K_TQ1_0 (QUANT_K_S_TQ1_0 + QUANT_K_H_TQ1_0)
struct block_tq1_0 {
uint8_t qs[QUANT_K_S_TQ1_0 / 5];
uint8_t qh[QUANT_K_H_TQ1_0 / 4];
float16_t d;
};
#if defined(DATA_A_TQ1_0)
#define QUANT_K QUANT_K_TQ1_0
#define QUANT_AUXF 1
#define A_TYPE block_tq1_0
#endif
// TQ2_0
#define QUANT_K_TQ2_0 256
#define QUANT_R_TQ2_0 4

View File

@ -51,6 +51,7 @@ const std::vector<std::string> type_names = {
"q5_1",
"q8_0",
"tq2_0",
"tq1_0",
"q2_k",
"q3_k",
"q4_k",
@ -684,7 +685,7 @@ void process_shaders() {
std::string shader;
if (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) {
shader = "mul_mat_vec_" + tname + ".comp";
} else if (tname == "tq2_0") {
} else if (tname == "tq2_0" || tname == "tq1_0") {
shader = "mul_mat_vec_tq.comp";
} else {
shader = "mul_mat_vec.comp";
@ -699,7 +700,7 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
if (tname != "tq2_0") {
if (tname != "tq2_0" && tname != "tq1_0") {
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@ -715,7 +716,7 @@ void process_shaders() {
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
} else if (tname == "tq2_0") {
} else if (tname == "tq2_0" || tname == "tq1_0") {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@ -723,13 +724,13 @@ void process_shaders() {
#endif
// Dequant shaders
if (tname == "tq2_0") {
if (tname == "tq2_0" || tname == "tq1_0") {
string_to_spv("dequant_" + tname, "dequant_tq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
} else if (tname != "f16" && tname != "bf16") {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}
if (tname == "tq2_0") continue;
if (tname == "tq2_0" || tname == "tq1_0") continue;
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
@ -1157,7 +1158,7 @@ void write_output_files() {
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0" && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0" && tname != "tq1_0" && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@ -1167,7 +1168,7 @@ void write_output_files() {
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
if (btype == "f16" || tname == "tq2_0") {
if (btype == "f16" || tname == "tq2_0" || tname == "tq1_0") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";