This commit is contained in:
Italo Nicola 2026-04-01 06:45:07 +00:00 committed by GitHub
commit b0670c6a16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 509 additions and 11 deletions

View File

@ -3505,6 +3505,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ2_0], matmul_tq2_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_TQ1_0], matmul_tq1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
@ -3595,6 +3597,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@ -3617,6 +3621,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@ -3714,6 +3720,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0], matmul_tq2_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0], matmul_tq1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@ -3877,6 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_TQ2_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ2_0].f32acc, matmul_tq2_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_TQ1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_TQ1_0].f32acc, matmul_tq1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@ -4044,6 +4054,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f32_f32", arr_dmmv_tq2_0_f32_f32_len[reduc], arr_dmmv_tq2_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f32_f32", arr_dmmv_tq1_0_f32_f32_len[reduc], arr_dmmv_tq1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
@ -4068,6 +4080,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_f16_f32", arr_dmmv_tq2_0_f16_f32_len[reduc], arr_dmmv_tq2_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_f16_f32", arr_dmmv_tq1_0_f16_f32_len[reduc], arr_dmmv_tq1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
@ -4089,6 +4103,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ2_0][i], "mul_mat_vec_tq2_0_q8_1_f32", arr_dmmv_tq2_0_q8_1_f32_len[reduc], arr_dmmv_tq2_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_TQ1_0][i], "mul_mat_vec_tq1_0_q8_1_f32", arr_dmmv_tq1_0_q8_1_f32_len[reduc], arr_dmmv_tq1_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
@ -4172,6 +4189,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ2_0], "dequant_tq2_0", dequant_tq2_0_len, dequant_tq2_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_TQ1_0], "dequant_tq1_0", dequant_tq1_0_len, dequant_tq1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
@ -6015,6 +6034,8 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -6086,6 +6107,8 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -6123,6 +6146,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
if (b_type == GGML_TYPE_Q8_1) {
switch (a_type) {
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@ -6151,6 +6176,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
@ -15248,6 +15275,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
return false;
}
if (src0_type == GGML_TYPE_TQ1_0 || src0_type == GGML_TYPE_TQ2_0) {
return false;
}
}
switch (src0_type) {
case GGML_TYPE_F32:
@ -15272,6 +15302,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_TQ2_0:
case GGML_TYPE_TQ1_0:
case GGML_TYPE_MXFP4:
break;
default:

View File

@ -421,6 +421,39 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_TQ2_0)
#include "tq_utils.glsl"
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(tq2_dequantize(ib + a_offset, iqs), tq2_dequantize(ib + a_offset, iqs + 1));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(
tq2_dequantize(ib + a_offset, iqs + 0),
tq2_dequantize(ib + a_offset, iqs + 1),
tq2_dequantize(ib + a_offset, iqs + 2),
tq2_dequantize(ib + a_offset, iqs + 3)
);
}
#endif
#if defined(DATA_A_TQ1_0)
#include "tq_utils.glsl"
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(tq1_dequantize(ib + a_offset, iqs), tq1_dequantize(ib + a_offset, iqs + 1));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
return vec4(
tq1_dequantize(ib + a_offset, iqs + 0),
tq1_dequantize(ib + a_offset, iqs + 1),
tq1_dequantize(ib + a_offset, iqs + 2),
tq1_dequantize(ib + a_offset, iqs + 3)
);
}
#endif
#if defined(DATA_A_MXFP4)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@ -448,7 +481,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}

View File

@ -666,6 +666,63 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
}
#endif
#if defined(DATA_A_TQ2_0)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 {
block_tq2_0 block;
};
float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint upper = idx / 128;
const uint byte = (upper * 32) + (idx % 32);
const uint shift = ((idx % 128) / 32) * 2;
const int c = (int(bl.block.qs[byte]) >> shift) & 3;
return d * float16_t(c - 1);
}
#endif
#if defined(DATA_A_TQ1_0)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ1_0 {
block_tq1_0 block;
};
float16_t dequantFuncTQ1_0(const in decodeBufTQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);
uint val;
if (idx < 160) {
const uint trit = idx / 32;
const uint byte = idx % 32;
const uint q = uint(bl.block.qs[byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
} else if (idx < 240) {
const uint relative_idx = idx - 160;
const uint trit = relative_idx / 16;
const uint byte = relative_idx % 16;
const uint q = uint(bl.block.qs[32 + byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
} else {
const uint relative_idx = idx - 240;
const uint trit = relative_idx / 4;
const uint byte = relative_idx % 4;
const uint q = uint(bl.block.qh[byte]);
val = (((q * pow3[trit]) & 255) * 3) / 256;
}
return d * float16_t(val - 1);
}
#endif
#if defined(DATA_A_MXFP4)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
block_mxfp4 block;
@ -727,6 +784,10 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_TQ2_0)
#define dequantFuncA dequantFuncTQ2_0
#elif defined(DATA_A_TQ1_0)
#define dequantFuncA dequantFuncTQ1_0
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)

View File

@ -0,0 +1,36 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#include "types.glsl"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter {
uint ne;
} p;
layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
#include "tq_utils.glsl"
void main() {
const uint i = gl_GlobalInvocationID.x * 4;
if (i >= p.ne) {
return;
}
const uint ib = i / QUANT_K;
const float d = float(data_a[ib].d);
for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) {
const uint e = (i + j) % QUANT_K;
#if defined(DATA_A_TQ1_0)
data_b[i + j] = D_TYPE(tq1_dequantize(ib, e) * d);
#elif defined(DATA_A_TQ2_0)
data_b[i + j] = D_TYPE(tq2_dequantize(ib, e) * d);
#endif
}
}

View File

@ -0,0 +1,59 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#include "mul_mat_vec_base.glsl"
#include "tq_utils.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint tid = gl_LocalInvocationID.x;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
for (uint r = 0; r < num_rows; ++r) {
const uint ib0 = a_offset / QUANT_K + (first_row + r) * num_blocks_per_row;
for (uint jcol = 0; jcol < NUM_COLS; ++jcol) {
const uint b_base = (jcol * p.batch_stride_b);
for (uint i = tid/8; i < num_blocks_per_row; i += gl_WorkGroupSize.x/8) {
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint e = tid % 8; e < 256; e += 8) {
#if defined(DATA_A_TQ1_0)
const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq1_dequantize(ib0 + i, e)) * d;
#elif defined(DATA_A_TQ2_0)
const FLOAT_TYPE dequant_val = FLOAT_TYPE(tq2_dequantize(ib0 + i, e)) * d;
#endif
const uint b_idx = i * QUANT_K + e;
temp[jcol][r] += dequant_val * FLOAT_TYPE(data_b[b_base + b_offset + b_idx]);
}
}
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,85 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_integer_dot_product : require
#define MMQ
#define B_TYPE block_q8_1_x4
#include "mul_mat_vec_base.glsl"
#include "mul_mat_vecq_funcs.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint tid = gl_LocalInvocationID.x;
for (uint jcol = 0; jcol < NUM_COLS; jcol++) {
const uint b_base = (jcol * p.batch_stride_b);
for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
FLOAT_TYPE acc = 0.0f;
for (uint i = tid/8; i < num_blocks_per_row; i+=gl_WorkGroupSize.x/8) {
const float d = float(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < 64; j += 32) {
[[unroll]] for (uint l = 0; l < 4; l+=2) {
const uint k = (tid % 8) * 4;
const uint a_idx = j * 4 + l * 32 + k;
const i32vec2 a_packed = repack(ib0 + i, a_idx);
const uint b0_idx = i * QUANT_K + j * 4 + l * 32;
const uint b1_idx = i * QUANT_K + j * 4 + (l+1) * 32;
const uint b0_block_idx = b_offset + (b_base + b0_idx) / QUANT_K_Q8_1;
const uint b1_block_idx = b_offset + (b_base + b1_idx) / QUANT_K_Q8_1;
const uint b0_block_idx_outer = b0_block_idx / 4;
const uint b1_block_idx_outer = b1_block_idx / 4;
const uint b0_block_idx_inner = b0_block_idx % 4;
const uint b1_block_idx_inner = b1_block_idx % 4;
vec2 ds0 = vec2(data_b[b0_block_idx_outer].ds[b0_block_idx_inner]);
vec2 ds1 = vec2(data_b[b1_block_idx_outer].ds[b1_block_idx_inner]);
const uint vec_idx = k / 4;
int32_t b0_packed = data_b[b0_block_idx_outer].qs[b0_block_idx_inner * 8 + vec_idx];
int32_t b1_packed = data_b[b1_block_idx_outer].qs[b1_block_idx_inner * 8 + vec_idx];
int32_t q0_sum = dotPacked4x8EXT(a_packed.x, b0_packed);
acc += mul_q8_1(q0_sum, d, ds0, 4);
int32_t q1_sum = dotPacked4x8EXT(a_packed.y, b1_packed);
acc += mul_q8_1(q1_sum, d, ds1, 4);
}
}
}
temp[jcol][n] = acc;
}
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -4,7 +4,7 @@
#include "types.glsl"
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ1_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
@ -112,6 +112,58 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i
}
#endif
#if defined(DATA_A_TQ2_0)
#include "tq_utils.glsl"
i32vec2 repack(uint ib, uint iqs) {
const int t00 = tq2_dequantize(ib, iqs + 0);
const int t01 = tq2_dequantize(ib, iqs + 1);
const int t02 = tq2_dequantize(ib, iqs + 2);
const int t03 = tq2_dequantize(ib, iqs + 3);
const int v0 = (t00 & 0xFF) | ((t01 & 0xFF) << 8) | ((t02 & 0xFF) << 16) | ((t03 & 0xFF) << 24);
const int t10 = tq2_dequantize(ib, iqs + 32 + 0);
const int t11 = tq2_dequantize(ib, iqs + 32 + 1);
const int t12 = tq2_dequantize(ib, iqs + 32 + 2);
const int t13 = tq2_dequantize(ib, iqs + 32 + 3);
const int v1 = (t10 & 0xFF) | ((t11 & 0xFF) << 8) | ((t12 & 0xFF) << 16) | ((t13 & 0xFF) << 24);
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * float(q_sum) * dsb.x);
}
#endif
#if defined(DATA_A_TQ1_0)
#include "tq_utils.glsl"
i32vec2 repack(uint ib, uint iqs) {
const int t00 = tq1_dequantize(ib, iqs + 0);
const int t01 = tq1_dequantize(ib, iqs + 1);
const int t02 = tq1_dequantize(ib, iqs + 2);
const int t03 = tq1_dequantize(ib, iqs + 3);
const int v0 = (t00 & 0xFF) | ((t01 & 0xFF) << 8) | ((t02 & 0xFF) << 16) | ((t03 & 0xFF) << 24);
const int t10 = tq1_dequantize(ib, iqs + 32 + 0);
const int t11 = tq1_dequantize(ib, iqs + 32 + 1);
const int t12 = tq1_dequantize(ib, iqs + 32 + 2);
const int t13 = tq1_dequantize(ib, iqs + 32 + 3);
const int v1 = (t10 & 0xFF) | ((t11 & 0xFF) << 8) | ((t12 & 0xFF) << 16) | ((t13 & 0xFF) << 24);
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * float(q_sum) * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
@ -135,6 +187,7 @@ FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const i
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(ib_a, iqs);
q_sum += dotPacked4x8EXT(data_a_qs.x,

View File

@ -1,3 +1,7 @@
#if defined(DATA_A_TQ2_0) || defined(DATA_A_TQ1_0)
#include "tq_utils.glsl"
#endif
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
@ -130,6 +134,42 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_TQ2_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128;
const uint iqs = idx % 128;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib].d);
const uint e0 = 2 * iqs;
const uint e1 = e0 + 1;
FLOAT_TYPE v0 = FLOAT_TYPE(tq2_dequantize(ib, e0)) * d;
FLOAT_TYPE v1 = FLOAT_TYPE(tq2_dequantize(ib, e1)) * d;
const vec2 v = vec2(v0, v1);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_TQ1_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
const uint ib = idx / 128;
const uint iqs = idx % 128;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib].d);
const uint e0 = 2 * iqs;
const uint e1 = e0 + 1;
FLOAT_TYPE v0 = FLOAT_TYPE(tq1_dequantize(ib, e0)) * d;
FLOAT_TYPE v1 = FLOAT_TYPE(tq1_dequantize(ib, e1)) * d;
const vec2 v = vec2(v0, v1);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

View File

@ -0,0 +1,45 @@
#ifndef TQ_UTILS_IMPL
#define TQ_UTILS_IMPL
#if defined(DATA_A_TQ2_0)
int tq2_dequantize(uint ib, uint iqs) {
const uint upper = iqs / 128;
const uint byte = (upper * 32) + (iqs % 32);
const uint shift = ((iqs % 128) / 32) * 2;
const int c = (int(data_a[ib].qs[byte]) >> shift) & 3;
return c - 1;
}
#endif
#if defined(DATA_A_TQ1_0)
int tq1_dequantize(uint ib, uint iqs) {
const uint pow3[6] = uint[6](1, 3, 9, 27, 81, 243);
if (iqs < 160) {
const uint trit = iqs / 32;
const uint byte = iqs % 32;
const uint q = uint(data_a[ib].qs[byte]);
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else if (iqs < 240) {
const uint relative_idx = iqs - 160;
const uint trit = relative_idx / 16;
const uint byte = relative_idx % 16;
const uint q = uint(data_a[ib].qs[32 + byte]);
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
} else {
const uint relative_idx = iqs - 240;
const uint trit = relative_idx / 4;
const uint byte = relative_idx % 4;
const uint q = uint(data_a[ib].qh[byte]);
const uint val = (((q * pow3[trit]) & 255) * 3) / 256;
return int(val) - 1;
}
}
#endif
#endif

View File

@ -1680,6 +1680,40 @@ struct block_iq4_nl_packed16
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
// TQ1_0
#define QUANT_K_S_TQ1_0 240
#define QUANT_K_H_TQ1_0 16
#define QUANT_K_TQ1_0 (QUANT_K_S_TQ1_0 + QUANT_K_H_TQ1_0)
struct block_tq1_0 {
uint8_t qs[QUANT_K_S_TQ1_0 / 5];
uint8_t qh[QUANT_K_H_TQ1_0 / 4];
float16_t d;
};
#if defined(DATA_A_TQ1_0)
#define QUANT_K QUANT_K_TQ1_0
#define QUANT_AUXF 1
#define A_TYPE block_tq1_0
#endif
// TQ2_0
#define QUANT_K_TQ2_0 256
#define QUANT_R_TQ2_0 4
struct block_tq2_0
{
uint8_t qs[QUANT_K_TQ2_0 / QUANT_R_TQ2_0];
float16_t d;
};
#if defined(DATA_A_TQ2_0)
#define QUANT_K QUANT_K_TQ2_0
#define QUANT_R QUANT_R_TQ2_0
#define QUANT_AUXF 1
#define A_TYPE block_tq2_0
#endif
#define QUANT_K_MXFP4 32
#define QUANT_R_MXFP4 2

View File

@ -50,6 +50,8 @@ const std::vector<std::string> type_names = {
"q5_0",
"q5_1",
"q8_0",
"tq2_0",
"tq1_0",
"q2_k",
"q3_k",
"q4_k",
@ -638,7 +640,7 @@ void process_shaders() {
}
for (const auto& tname : type_names) {
if (tname == "bf16") continue;
if (tname == "bf16" || tname == "tq1_0" || tname == "tq2_0") continue;
if (fp16) {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@ -680,7 +682,14 @@ void process_shaders() {
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
std::string shader;
if (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) {
shader = "mul_mat_vec_" + tname + ".comp";
} else if (tname == "tq2_0" || tname == "tq1_0") {
shader = "mul_mat_vec_tq.comp";
} else {
shader = "mul_mat_vec.comp";
}
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
@ -691,9 +700,11 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
if (tname != "tq2_0" && tname != "tq1_0") {
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@ -705,14 +716,22 @@ void process_shaders() {
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
} else if (tname == "tq2_0" || tname == "tq1_0") {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vec_tq_q.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
// Dequant shaders
if (tname != "f16" && tname != "bf16") {
if (tname == "tq2_0" || tname == "tq1_0") {
string_to_spv("dequant_" + tname, "dequant_tq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
} else if (tname != "f16" && tname != "bf16") {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}
if (tname == "tq2_0" || tname == "tq1_0") continue;
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
@ -1147,7 +1166,7 @@ void write_output_files() {
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "tq2_0" && tname != "tq1_0" && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@ -1157,7 +1176,7 @@ void write_output_files() {
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
if (btype == "f16") {
if (btype == "f16" || tname == "tq2_0" || tname == "tq1_0") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";

View File

@ -7288,7 +7288,8 @@ static const ggml_type all_types[] = {
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
// GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends
GGML_TYPE_TQ1_0,
GGML_TYPE_TQ2_0,
GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,