Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900)
* vulkan: split mul_mmq_funcs for mul_mat_vecq use
* add mxfp4 mmvq
* add q2_k mmvq
* add q3_k mmvq
* add q4_k and q5_k mmvq
* add q6_k mmvq
* handle 4x4 quants per mmvq thread
* enable MUL_MAT_ID mmvq support
* enable subgroup optimizations for mul_mat_vec_id shaders
* device tuning
* request prealloc_y sync after quantization
* fix indentation
* fix llvmpipe test failures
* fix mul_mat_id mmvq condition
* fix unused variable warning
This commit is contained in:
parent 59d8d4e963
commit 47a268ea50
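Background for the diff below: MMVQ ("mul_mat_vec quantized") quantizes the activation vector to Q8_1 and replaces the float dequantize-multiply inner loop with packed int8 dot products, applying the block scales once per block. A minimal scalar C++ sketch of that idea for a Q8_0 weight block, assuming nothing about the actual shader code:

#include <cstdint>

// Sketch only: one Q8_0 weight block (32 int8 quants, scale d) against one
// Q8_1 activation block (32 int8 quants, scale ds_x). The GPU shader does
// the inner loop with dotPacked4x8EXT on packed int32 words; spelled out
// scalar here for clarity.
float mmvq_block_q8_0_q8_1(const int8_t wq[32], float d,
                           const int8_t aq[32], float ds_x) {
    int32_t q_sum = 0;
    for (int i = 0; i < 32; ++i) {
        q_sum += int32_t(wq[i]) * int32_t(aq[i]);   // exact integer MACs
    }
    return d * ds_x * float(q_sum);                 // both scales applied once per block
}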
@@ -613,9 +613,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
-    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
+    vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
 
     vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@@ -1611,7 +1612,7 @@ class vk_perf_logger {
         }
         if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
             const uint64_t m = node->src[0]->ne[1];
-            const uint64_t n = node->ne[1];
+            const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
             const uint64_t k = node->src[1]->ne[0];
             const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
             std::string name = ggml_op_name(node->op);
@@ -3525,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // the number of rows computed per shader depends on GPU model and quant
     uint32_t rm_stdq = 1;
     uint32_t rm_kq = 2;
+    uint32_t rm_stdq_int = 1;
+    uint32_t rm_kq_int = 1;
     if (device->vendor_id == VK_VENDOR_ID_AMD) {
         if (device->architecture == AMD_GCN) {
             rm_stdq = 2;
            rm_kq = 4;
+            rm_stdq_int = 4;
        }
-    } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
+    } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
         rm_stdq = 2;
+        rm_stdq_int = 2;
+    }
     uint32_t rm_iq = 2 * rm_kq;
 
     const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
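The rm_* constants are output rows computed per workgroup, so they directly set the dispatch width: a larger multiplier (e.g. rm_stdq_int = 4 on AMD GCN) means fewer, fatter workgroups. A small sketch of that relationship (helper name hypothetical):

#include <cstdint>

// Sketch only: each mul_mat_vec workgroup produces `rows_per_wg` output rows,
// so the dispatch ceil-divides the row count over workgroups.
uint32_t mmv_workgroups(uint32_t nrows, uint32_t rows_per_wg) {
    return (nrows + rows_per_wg - 1) / rows_per_wg;
}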
@@ -3612,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) {
             const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
             const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
 
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
-            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
         }
 #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
 
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
+
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        if (device->integer_dot_product) {
+            const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
+            const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
+
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
+        }
+#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
     }
 
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+    GGML_UNUSED(rm_stdq_int);
+    GGML_UNUSED(rm_kq_int);
+#endif
 
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
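Shape of the new pipeline tables after this hunk: the MUL_MAT_ID vector pipelines gain a leading workgroup-size dimension, mirroring the existing non-ID q8_1 tables. A simplified C++ sketch (enum values illustrative; a fixed count stands in for GGML_TYPE_COUNT):

// Sketch only: one pipeline per (workgroup-size variant, quant type) pair
// instead of one per quant type.
enum { DMMV_WG_SIZE_SUBGROUP = 0, DMMV_WG_SIZE_LARGE = 1, DMMV_WG_SIZE_COUNT = 2 };

struct pipeline_tables_sketch {
    void * mul_mat_vec_id_f32 [DMMV_WG_SIZE_COUNT][40];   // 40 stands in for GGML_TYPE_COUNT
    void * mul_mat_vec_id_q8_1[DMMV_WG_SIZE_COUNT][40];   // new integer-dot ID table
};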
@@ -5453,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
             break;
         default:
             return nullptr;
@@ -5592,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
     }
 }
 
-static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
+static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
     VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
-    GGML_ASSERT(b_type == GGML_TYPE_F32);
+    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
+
+    if (b_type == GGML_TYPE_Q8_1) {
+        switch (a_type) {
+            case GGML_TYPE_Q4_0:
+            case GGML_TYPE_Q4_1:
+            case GGML_TYPE_Q5_0:
+            case GGML_TYPE_Q5_1:
+            case GGML_TYPE_Q8_0:
+            case GGML_TYPE_MXFP4:
+            case GGML_TYPE_Q2_K:
+            case GGML_TYPE_Q3_K:
+            case GGML_TYPE_Q4_K:
+            case GGML_TYPE_Q5_K:
+            case GGML_TYPE_Q6_K:
+                break;
+            default:
+                return nullptr;
+        }
+    }
 
     switch (a_type) {
         case GGML_TYPE_F32:
@@ -5625,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
             return nullptr;
     }
 
-    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
+    // heuristic to choose workgroup size
+    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
+    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+        // Prefer larger workgroups when M is small, to spread the work out more
+        // and keep more SMs busy.
+        // q6_k seems to prefer small workgroup size even for "medium" values of M.
+        if (a_type == GGML_TYPE_Q6_K) {
+            if (m < 4096 && k >= 1024) {
+                dmmv_wg = DMMV_WG_SIZE_LARGE;
+            }
+        } else {
+            if (m <= 8192 && k >= 1024) {
+                dmmv_wg = DMMV_WG_SIZE_LARGE;
+            }
+        }
+    }
+
+    if (b_type == GGML_TYPE_Q8_1) {
+        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
+            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
+        }
+        return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
+    }
+
+    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
 }
 
 static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
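Worked example of the heuristic above on NVIDIA Turing or newer: Q4_K with m = 4096 and k = 8192 selects the large-workgroup variant (m <= 8192 and k >= 1024), while Q6_K with the same shape does not (its cutoff is m < 4096). A paraphrase, for illustration only:

#include <cstdint>

// Paraphrase of the workgroup-size heuristic, illustration only.
bool prefers_large_wg(bool is_q6_k, uint32_t m, uint32_t k) {
    return is_q6_k ? (m < 4096 && k >= 1024)
                   : (m <= 8192 && k >= 1024);
}
// prefers_large_wg(false, 4096, 8192) == true   (Q4_K takes the large variant)
// prefers_large_wg(true,  4096, 8192) == false  (Q6_K stays subgroup-sized)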
@@ -6817,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
         return false;
     }
 
+    // General performance issue with q3_k and q6_k due to 2-byte alignment
+    if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
+        return false;
+    }
+
     // MMVQ is generally good for batches
     if (n > 1) {
         return true;
     }
 
+    // Quantization overhead is not worth it for small k
     switch (device->vendor_id) {
         case VK_VENDOR_ID_NVIDIA:
+            if (k <= 4096) {
+                return false;
+            }
+
             switch (src0_type) {
+                case GGML_TYPE_MXFP4:
                 case GGML_TYPE_Q8_0:
                     return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
                 default:
                     return true;
             }
         case VK_VENDOR_ID_AMD:
+            if (k < 2048) {
+                return false;
+            }
+
             switch (src0_type) {
                 case GGML_TYPE_Q8_0:
                     return device->architecture == vk_device_architecture::AMD_GCN;
@@ -6838,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
                     return true;
             }
         case VK_VENDOR_ID_INTEL:
+            if (k < 2048) {
+                return false;
+            }
+
             switch (src0_type) {
                 // From tests on A770 Linux, may need more tuning
                 case GGML_TYPE_Q4_0:
@@ -6851,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
     }
 
     GGML_UNUSED(m);
-    GGML_UNUSED(k);
 }
 
 static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
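Condensed view of the NVIDIA branch of ggml_vk_should_use_mmvq after this change, for illustration only (the real function also covers the AMD and Intel cases shown above):

#include <cstdint>

// Paraphrase of the NVIDIA decision path, illustration only.
bool should_use_mmvq_nvidia_sketch(bool pre_turing, uint32_t n, uint32_t k,
                                   bool q3k_or_q6k, bool mxfp4_or_q8_0) {
    if (q3k_or_q6k)    return false;       // 2-byte alignment issue, see comment above
    if (n > 1)         return true;        // batches generally win with MMVQ
    if (k <= 4096)     return false;       // quantization overhead not worth it
    if (mxfp4_or_q8_0) return pre_turing;  // these only win on pre-Turing
    return true;
}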
@@ -7574,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     if (x_non_contig || qx_needs_dequant) {
         ctx->prealloc_x_need_sync = true;
     }
-    if (y_non_contig) {
+    if (y_non_contig || quantize_y) {
         ctx->prealloc_y_need_sync = true;
     }
 }
@@ -7600,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     const uint64_t ne10 = src1->ne[0];
     const uint64_t ne11 = src1->ne[1];
-    // const uint64_t ne12 = src1->ne[2];
+    const uint64_t ne12 = src1->ne[2];
     // const uint64_t ne13 = src1->ne[3];
 
     const uint64_t nei0 = ids->ne[0];
@@ -7617,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 
     const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
-    const bool qx_needs_dequant = x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
-
-    // Not implemented
-    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
-
-    const uint64_t x_ne = ggml_nelements(src0);
-    const uint64_t y_ne = ggml_nelements(src1);
-
-    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
-    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
-    const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
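quantize_y only fires for contiguous f32 activations whose row volume is a multiple of 4, because the shader reads the quantized data as packed 32-bit words. For reference, the Q8_1 block it produces looks roughly like this (sketch of ggml's block_q8_1; fp16 halves shown as uint16_t bit patterns to keep the sketch self-contained):

#include <cstdint>

// Sketch of a Q8_1 block (32 quants): a scale d, a precomputed sum term s,
// and 32 int8 quants that the shader loads four at a time as int32 words.
struct block_q8_1_sketch {
    uint16_t d;        // fp16 scale
    uint16_t s;        // fp16 sum term, lets offset-style weight quants fold in cheaply
    int8_t   qs[32];   // quantized activations
};
static_assert(sizeof(block_q8_1_sketch) == 36, "4 header bytes + 32 quants");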
@@ -7641,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
-    vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
+    // Check for mmq first
+    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
+    vk_pipeline to_q8_1 = nullptr;
+
+    if (dmmv == nullptr) {
+        // Fall back to f16 dequant mul mat
+        dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
+        quantize_y = false;
+    }
+
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    }
+
+    const bool qx_needs_dequant = x_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
+
+    // Not implemented
+    GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
     GGML_ASSERT(dmmv != nullptr);
+
+    const uint64_t x_ne = ggml_nelements(src0);
+    const uint64_t y_ne = ggml_nelements(src1);
+
+    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
+    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
+    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
+                                       (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
 
     {
         if (
             (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
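A worked example of the quantize_y branch of the y_sz formula, using the 36-byte Q8_1 block sketched earlier (helper hypothetical; the real code calls ggml_vk_align_size, ggml_type_size, and ggml_blck_size):

#include <cstdint>

// y_sz for quantize_y: align the element count to 128, then charge
// 36 bytes per 32 elements (one Q8_1 block).
uint64_t q8_1_buffer_size(uint64_t y_ne) {
    const uint64_t aligned = (y_ne + 127) / 128 * 128;   // align_size(y_ne, 128)
    return aligned * 36 / 32;                            // * type_size / block_size
}
// q8_1_buffer_size(4096) == 4608: 128 blocks of 36 bytes each.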
@@ -7656,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         ctx->prealloc_size_x = x_sz;
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
-    if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) {
+    if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
         ctx->prealloc_size_y = y_sz;
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
@@ -7668,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
+        }
         ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
     }
@@ -7683,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     } else {
         d_X = d_Qx;
     }
-    if (qy_needs_dequant) {
+    if (qy_needs_dequant || quantize_y) {
         d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
     } else {
         d_Y = d_Qy;
@@ -7711,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
             ctx->prealloc_y_last_tensor_used = src1;
         }
     }
+    if (quantize_y) {
+        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
+            ctx->prealloc_y_last_tensor_used != src1) {
+            if (ctx->prealloc_y_need_sync) {
+                ggml_vk_sync_buffers(ctx, subctx);
+            }
+            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
+            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
+            ctx->prealloc_y_last_tensor_used = src1;
+        }
+    }
 
     uint32_t stride_batch_y = ne10*ne11;
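The quantize step reuses the same prealloc_y caching pattern as the dequant path: src1 is requantized only when the cached pipeline/tensor pair changes, so several experts reading one activation vector quantize it once. A sketch of the shape of that check (names hypothetical):

// Sketch only: requantize only when the cached (pipeline, tensor) pair changes.
struct prealloc_y_cache_sketch {
    const void * last_pipeline = nullptr;
    const void * last_tensor   = nullptr;

    bool needs_requantize(const void * pipeline, const void * tensor) {
        if (last_pipeline == pipeline && last_tensor == tensor) {
            return false;            // hot path: reuse the quantized copy
        }
        last_pipeline = pipeline;
        last_tensor   = tensor;
        return true;
    }
};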
@@ -7772,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
     }
-    if (y_non_contig) {
+    if (y_non_contig || quantize_y) {
         ctx->prealloc_y_need_sync = true;
     }
 }
@@ -4,13 +4,6 @@
 
 #include "types.glsl"
 
-#if defined(A_TYPE_PACKED16)
-layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
-#endif
-#if defined(A_TYPE_PACKED32)
-layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
-#endif
-
 #if defined(DATA_A_F32)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
@@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
 
 #if !RMS_NORM_ROPE_FUSION
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 #endif
@@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
 } p;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
+#endif
+
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
 
 uint get_idx() {
@@ -3,6 +3,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.glsl"
+#include "dequant_funcs.glsl"
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@@ -13,8 +13,6 @@
 
 #include "mul_mat_vec_iface.glsl"
 
-#include "dequant_funcs.glsl"
-
 layout (push_constant) uniform parameter
 {
     uint ncols;
@@ -5,13 +5,15 @@
 #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
 
-#ifndef MMQ
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 #if defined(A_TYPE_VEC4)
 layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
 #endif
-#else
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#if defined(A_TYPE_PACKED16)
+layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
+#endif
+#if defined(A_TYPE_PACKED32)
+layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
 #endif
 
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
@@ -10,60 +10,56 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
 #define K_PER_ITER 8
-#include "mul_mmq_funcs.glsl"
+#elif defined(DATA_A_QUANT_K)
+#define K_PER_ITER 16
+#else
+#error unimplemented
+#endif
 
 uint a_offset, b_offset, d_offset;
 
-int32_t cache_b_qs[2];
+int32_t cache_b_qs[K_PER_ITER / 4];
 vec2 cache_b_ds;
 
+#include "mul_mat_vecq_funcs.glsl"
+
 void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
         const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
 
         // Preload data_b block
         const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
-        const uint b_qs_idx = tid % 4;
+        const uint b_qs_idx = tid % (32 / K_PER_ITER);
         const uint b_block_idx_outer = b_block_idx / 4;
         const uint b_block_idx_inner = b_block_idx % 4;
         cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
 
 #if QUANT_R == 2
+        // Assumes K_PER_ITER == 8
         cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
         cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
 #else
+#if K_PER_ITER == 8
         cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
         cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
+#elif K_PER_ITER == 16
+        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4    ];
+        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
+        cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
+        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
+#else
+#error unimplemented
+#endif
 #endif
 
         uint ibi = first_row*p.ncols;
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-            const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
+            const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
             ibi += p.ncols;
 
-            int32_t q_sum = 0;
-#if QUANT_R == 2
-            const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
-            q_sum += dotPacked4x8EXT(data_a_qs.x,
-                                     cache_b_qs[0]);
-            q_sum += dotPacked4x8EXT(data_a_qs.y,
-                                     cache_b_qs[1]);
-#else
-            int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[0]);
-            data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
-            q_sum += dotPacked4x8EXT(data_a_qs,
-                                     cache_b_qs[1]);
-#endif
-
-#if QUANT_AUXF == 1
-            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
-#else
-            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
-#endif
+            temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
         }
     }
 }
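mmvq_dot_product now hides the per-quant repack + dotPacked4x8EXT sequence that used to be inlined here. What one dotPacked4x8EXT call computes, emulated in scalar C++ (illustration only):

#include <cstdint>

// Emulation of GLSL dotPacked4x8EXT(a, b): treat each int32 as four packed
// signed 8-bit lanes and return the lane-wise dot product.
int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int lane = 0; lane < 4; ++lane) {
        const int8_t av = static_cast<int8_t>((uint32_t(a) >> (8 * lane)) & 0xFF);
        const int8_t bv = static_cast<int8_t>((uint32_t(b) >> (8 * lane)) & 0xFF);
        sum += int32_t(av) * int32_t(bv);
    }
    return sum;  // with K_PER_ITER == 16, each thread issues four of these per iteration
}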
@@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
 
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K;
+    a_offset /= QUANT_K_Q8_1;
     b_offset /= QUANT_K_Q8_1;
 
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
         unroll_count = 2;
         unrolled_iters = num_iters & ~(unroll_count - 1);
 
-#if K_PER_ITER == 2
-        if ((p.ncols & 1) != 0 &&
-            unrolled_iters == num_iters &&
-            unrolled_iters > 0) {
-            unrolled_iters -= unroll_count;
-        }
-#endif
-
         while (i < unrolled_iters) {
             // Manually partially unroll the loop
             [[unroll]] for (uint k = 0; k < unroll_count; ++k) {

@@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 void main() {
     const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
+#endif
+
     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {
         compute_outputs(first_row, NUM_ROWS);

@@ -0,0 +1,379 @@
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+
+#include "types.glsl"
+
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+FLOAT_TYPE get_dm(uint ib) {
+    return FLOAT_TYPE(data_a[ib].d);
+}
+#endif
+
+#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
+}
+#endif
+
+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_dm(uint ib) {
+    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
+#if defined(DATA_A_Q2_K)
+FLOAT_TYPE_VEC2 get_dm(uint ib) {
+    const uint ib_k = ib / 8;
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+}
+#endif
+
+// Each iqs value maps to a 32-bit integer
+#if defined(DATA_A_Q4_0)
+// 2-byte loads for Q4_0 blocks (18 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
+    const uint32_t vui = pack32(quants);
+    return i32vec2( vui        & 0x0F0F0F0F,
+                   (vui >>  4) & 0x0F0F0F0F);
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
+}
+#endif
+
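Note: Q4_0 stores unsigned nibbles with an implicit bias of 8, and the mul_q8_1 form above folds that bias into the q8_1 side: since sum((q - 8) * b) == sum(q * b) - 8 * sum(b), the correction rides on dsb.y, which holds the q8_1 block's d * sum(b) term. The sum_divisor argument rescales that term when a call covers only part of the block (the mmvq path passes 4, per the "2 quants per call" comment further down). A CPU sketch with made-up values:

    #include <cstdint>
    #include <cstdio>

    // sum((q - 8) * b) == sum(q * b) - 8 * sum(b), so the -8 bias can be
    // folded into the q8_1 side, whose ds.y term already stores d_b * sum(b).
    int main() {
        const int8_t q[4] = { 1, 15, 7, 8 };  // raw Q4_0 nibbles, bias 8 (made up)
        const int8_t b[4] = { 3, -2, 5, 1 };  // q8_1 quants (made up)
        const float  da = 0.1f;               // Q4_0 scale
        const float  db = 0.5f;               // q8_1 scale (ds.x)
        int32_t q_sum = 0, b_sum = 0;
        for (int i = 0; i < 4; ++i) {
            q_sum += q[i] * b[i];
            b_sum += b[i];
        }
        // reference: dequantize first, then multiply-accumulate
        float ref = 0.0f;
        for (int i = 0; i < 4; ++i) {
            ref += da * (q[i] - 8) * (db * b[i]);
        }
        // shader form: da * (q_sum * ds.x - 8 * ds.y), with ds = (d_b, d_b * sum(b))
        float fast = da * (q_sum * db - 8.0f * (db * b_sum));
        printf("ref=%f fast=%f\n", ref, fast); // identical up to rounding
        return 0;
    }

Q4_1 below carries a (scale, min) pair instead, so its min contribution is added via dma.y * dsb.y rather than subtracted.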
+#if defined(DATA_A_Q4_1)
+// 4-byte loads for Q4_1 blocks (20 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const uint32_t vui = data_a_packed32[ib].qs[iqs];
+    return i32vec2( vui        & 0x0F0F0F0F,
+                   (vui >>  4) & 0x0F0F0F0F);
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
+#endif
+
+#if defined(DATA_A_Q5_0)
+// 2-byte loads for Q5_0 blocks (22 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
+    const uint32_t vui = pack32(quants);
+    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
+    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+                       | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                       | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+    return i32vec2(v0, v1);
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
+}
+#endif
+
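Note: the `* 0x02040810 & 0x10101010` trick above scatters four high bits into bit positions 4, 12, 20 and 28 — one per byte lane — so they can be ORed onto the nibbles before the packed dot product. The multiplier's set bits sit at positions 4, 11, 18 and 25, so the four shifted copies of a 4-bit value never overlap and the multiply cannot carry between lanes. A quick CPU check:

    #include <cstdint>
    #include <cstdio>

    // Verifies that (qh * 0x02040810) & 0x10101010 moves bit i of qh
    // to bit position 4 + 8*i for every 4-bit pattern.
    int main() {
        for (uint32_t qh = 0; qh < 16; ++qh) {
            uint32_t scattered = (qh * 0x02040810u) & 0x10101010u;
            uint32_t expected = 0;
            for (int i = 0; i < 4; ++i) {
                if (qh & (1u << i)) {
                    expected |= 1u << (4 + 8 * i);
                }
            }
            if (scattered != expected) {
                printf("mismatch at qh=%u\n", qh);
                return 1;
            }
        }
        printf("all 16 qh patterns scatter correctly\n");
        return 0;
    }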
+#if defined(DATA_A_Q5_1)
+// 4-byte loads for Q5_1 blocks (24 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
+    const uint32_t vui = pack32(quants);
+    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
+    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
+                       | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
+
+    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
+                       | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
+
+    return i32vec2(v0, v1);
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+// 2-byte loads for Q8_0 blocks (34 bytes)
+int32_t repack(uint ib, uint iqs) {
+    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
+                          data_a_packed16[ib].qs[iqs * 2 + 1]));
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(float(q_sum) * da * dsb.x);
+}
+#endif
+
+#if defined(DATA_A_MXFP4)
+// 1-byte loads for mxfp4 blocks (17 bytes)
+i32vec2 repack(uint ib, uint iqs) {
+    const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
+                                      data_a[ib].qs[iqs * 4 + 1],
+                                      data_a[ib].qs[iqs * 4 + 2],
+                                      data_a[ib].qs[iqs * 4 + 3]));
+
+    const u8vec4 i_a0 = unpack8( qs       & 0x0F0F0F0F);
+    const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
+
+    return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
+                   pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
+}
+
+FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
+    return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
+}
+#endif
+
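Note: the mxfp4 path scales by an e8m0 exponent byte and maps nibbles through the kvalues_mxfp4 lookup table. A CPU model of both, assuming the usual e8m0 meaning (a bare biased FP32 exponent, ignoring the e == 0 edge case) and a table holding the FP4 code points doubled — which the `* 0.5` in mul_q8_1 above compensates for:

    #include <cstdint>
    #include <cstring>
    #include <cstdio>

    // Model of the e8m0 scale decode: place the byte in the FP32 exponent
    // field, giving 2^(e - 127).
    static float e8m0_to_fp32_model(uint8_t e) {
        uint32_t bits = uint32_t(e) << 23;
        float f;
        memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // Assumed doubled FP4 (E2M1) code points, in kvalues_mxfp4 order.
    static const int8_t kvalues_mxfp4_model[16] = {
        0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
    };

    int main() {
        printf("e8m0(127) = %f\n", e8m0_to_fp32_model(127)); // 1.0
        printf("e8m0(130) = %f\n", e8m0_to_fp32_model(130)); // 8.0
        printf("nibble 0x5 -> %d (3.0 after * 0.5)\n", kvalues_mxfp4_model[5]);
        return 0;
    }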
+#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t q_sum = 0;
+#if QUANT_R == 2
+    const i32vec2 data_a_qs = repack(ib_a, iqs);
+    q_sum += dotPacked4x8EXT(data_a_qs.x, cache_b_qs[0]);
+    q_sum += dotPacked4x8EXT(data_a_qs.y, cache_b_qs[1]);
+#else
+    int32_t data_a_qs = repack(ib_a, iqs * 2);
+    q_sum += dotPacked4x8EXT(data_a_qs, cache_b_qs[0]);
+    data_a_qs = repack(ib_a, iqs * 2 + 1);
+    q_sum += dotPacked4x8EXT(data_a_qs, cache_b_qs[1]);
+#endif
+
+    // 2 quants per call => divide sums by 8/2 = 4
+    return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
+}
+#endif
+
+#if defined(DATA_A_Q2_K)
+// 4-byte loads for Q2_K blocks (84 bytes)
+i32vec4 repack4(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+
+    return i32vec4((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303,
+                   (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
+                   (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
+                   (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
+}
+
+uint8_t get_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    return data_a[ib_k].scales[iqs_k / 4];
+}
+
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t sum_d = 0;
+    int32_t sum_m = 0;
+
+    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
+    const uint8_t scale = get_scale(ib_a, iqs * 4);
+    const vec2 dm = vec2(get_dm(ib_a));
+    const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+
+    sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
+
+    sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
+
+    sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
+
+    sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
+    sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
+
+    return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
+}
+#endif
+
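Note: the sum_d/sum_m split above follows from each Q2_K weight being w = d * sc_lo * q - dmin * sc_hi. The dot product against b then factors into one term scaled by the 4-bit sub-block scale and one scaled by the 4-bit min, and the min term is obtained by dotting the duplicated scale_m word against b, which yields sc_hi * sum(b). A CPU sketch with made-up values:

    #include <cstdint>
    #include <cstdio>

    // sum(b * (d*sc_lo*q - dmin*sc_hi)) factors into
    // d*sc_lo*dot(q,b) - dmin*(sc_hi*sum(b)), matching the shader's
    // sum_d and sum_m accumulators.
    int main() {
        const int32_t q[4]  = { 0, 3, 1, 2 };   // 2-bit quants (made up)
        const int32_t b[4]  = { 7, -1, 4, 2 };  // q8_1 quants (made up)
        const int32_t sc_lo = 9, sc_hi = 5;     // nibbles of the scale byte
        const float   d = 0.25f, dmin = 0.125f; // super-block scales (dm.x, dm.y)
        const float   db = 0.5f;                // q8_1 scale (cache_b_ds.x)

        float ref = 0.0f;
        int32_t sum_d = 0, sum_m = 0;
        for (int i = 0; i < 4; ++i) {
            ref   += db * b[i] * (d * sc_lo * q[i] - dmin * sc_hi);
            sum_d += q[i] * b[i] * sc_lo;
            sum_m += sc_hi * b[i];          // what dot(scale_m, b) computes
        }
        printf("ref=%f fast=%f\n", ref, db * (d * sum_d - dmin * sum_m));
        return 0;
    }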
+#if defined(DATA_A_Q3_K)
+// 2-byte loads for Q3_K blocks (110 bytes)
+i32vec4 repack4(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    const uint hm_shift = iqs_k / 8;
+
+    // bitwise OR to add 4 if hmask is set, subtract later
+    const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2    ] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2    ] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
+    const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
+                          unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
+
+    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
+                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
+                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
+                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
+}
+
+float get_d_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+    const uint is = iqs_k / 4;
+
+    const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8    ] >> (4 * (is / 8))) & 0x0F0F) |
+                                (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
+    return float(data_a[ib_k].d) * float(scale - 32);
+}
+
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t q_sum = 0;
+
+    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
+    const float d_scale = get_d_scale(ib_a, iqs * 4);
+
+    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
+    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
+    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
+    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
+
+    return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
+}
+#endif
+
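Note: the "OR in 4, subtract later" comment above encodes Q3_K's quant range [-4, 3]: two low bits come from qs and the third bit from hmask, and (low2 | hbit << 2) - 4 rebuilds the signed value branch-free across all byte lanes. A CPU check of the reconstruction:

    #include <cstdint>
    #include <cstdio>

    // Round-trips every Q3_K quant value through the split storage
    // (2 low bits + 1 hmask bit) and the OR-then-subtract rebuild.
    int main() {
        for (int v = -4; v <= 3; ++v) {
            uint32_t stored = uint32_t(v + 4);      // biased by 4
            uint32_t low2 = stored & 0x3;
            uint32_t hbit = (stored >> 2) & 0x1;
            int recon = int(low2 | (hbit << 2)) - 4;
            printf("v=%2d low2=%u hbit=%u recon=%2d\n", v, low2, hbit, recon);
        }
        return 0;
    }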
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
+// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
+i32vec4 repack4(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
+    const uint qs_shift = ((iqs_k % 16) / 8) * 4;
+
+#if defined(DATA_A_Q4_K)
+    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F;
+    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
+    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
+    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
+
+    return i32vec4(vals0, vals1, vals2, vals3);
+#else // defined(DATA_A_Q5_K)
+    const uint qh_idx = iqs;
+    const uint qh_shift = iqs_k / 8;
+
+    return i32vec4(((data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x0F0F0F0F) |
+                   (((data_a_packed32[ib_k].qh[qh_idx    ] >> qh_shift) & 0x01010101) << 4),
+                   ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
+                   (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
+                   ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
+                   (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
+                   ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
+                   (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
+#endif
+}
+
+vec2 get_dm_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+    const uint is = iqs_k / 8;
+    u8vec2 scale_dm;
+    if (is < 4) {
+        scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
+    } else {
+        scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
+                          (data_a[ib_k].scales[is+4] >>  4) | ((data_a[ib_k].scales[is  ] & 0xC0) >> 2));
+    }
+
+    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+}
+
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t q_sum = 0;
+
+    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
+    const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
+
+    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
+    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
+    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
+    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
+
+    return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
+}
+#endif
+
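Note: get_dm_scale above unpacks the K-quant 12-byte scales array: sub-blocks 0..3 keep their 6-bit scale/min in the low bits of bytes 0..7, while sub-blocks 4..7 are stitched together from bytes 8..11 plus the top two bits of bytes 0..7. A CPU model of that layout (mirroring ggml's get_scale_min_k4 helper, with made-up packed bytes):

    #include <cstdint>
    #include <cstdio>

    // Unpacks one (scale, min) pair from the 12-byte Q4_K/Q5_K scales array.
    static void get_scale_min(const uint8_t sc[12], int is,
                              uint8_t * scale, uint8_t * min) {
        if (is < 4) {
            *scale = sc[is] & 0x3F;
            *min   = sc[is + 4] & 0x3F;
        } else {
            *scale = uint8_t((sc[is + 4] & 0x0F) | ((sc[is - 4] & 0xC0) >> 2));
            *min   = uint8_t((sc[is + 4] >>   4) | ((sc[is    ] & 0xC0) >> 2));
        }
    }

    int main() {
        // hypothetical packed scales for one 256-wide super-block
        uint8_t sc[12] = { 0x7F, 0x41, 0x02, 0x83, 0x04, 0xC5, 0x06, 0x07,
                           0xA8, 0x19, 0x2A, 0x3B };
        for (int is = 0; is < 8; ++is) {
            uint8_t s, m;
            get_scale_min(sc, is, &s, &m);
            printf("is=%d scale=%u min=%u\n", is, s, m);
        }
        return 0;
    }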
+#if defined(DATA_A_Q6_K)
+// 2-byte loads for Q6_K blocks (210 bytes)
+i32vec4 repack4(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+
+    const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
+    const uint ql_shift = ((iqs_k % 32) / 16) * 4;
+
+    const uint qh_idx = (iqs_k / 32) * 8 + iqs;
+    const uint qh_shift = ((iqs_k % 32) / 8) * 2;
+
+    const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2    ] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2    ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+    const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
+                           unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
+
+    return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
+                   pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
+                   pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
+                   pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
+}
+
+float get_d_scale(uint ib, uint iqs) {
+    const uint ib_k = ib / 8;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
+    return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
+}
+
+FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
+    int32_t q_sum = 0;
+
+    const i32vec4 qs_a = repack4(ib_a, iqs * 4);
+    const float d_scale = get_d_scale(ib_a, iqs * 4);
+
+    q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
+    q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
+    q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
+    q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
+
+    return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
+}
+#endif
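Note: Q6_K splits each 6-bit quant in [-32, 31] into four ql bits and two qh bits, which repack4 above recombines as (ql | qh << 4) - 32. A CPU check over a few values:

    #include <cstdint>
    #include <cstdio>

    // Round-trips sample Q6_K quants through the ql/qh split and rebuild.
    int main() {
        for (int v = -32; v <= 31; v += 9) {
            uint32_t u  = uint32_t(v + 32);      // stored form, biased by 32
            uint32_t ql = u & 0x0F;
            uint32_t qh = (u >> 4) & 0x03;
            int recon = int(ql | (qh << 4)) - 32;
            printf("v=%3d ql=%2u qh=%u recon=%3d\n", v, ql, qh, recon);
        }
        return 0;
    }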

@@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;

 #define BK 32

-#define MMQ_SHMEM
-
 #include "mul_mmq_shmem_types.glsl"

 #ifdef MUL_MAT_ID

@@ -9,31 +9,6 @@
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
 // 2-byte loads for Q4_0 blocks (18 bytes)
 // 4-byte loads for Q4_1 blocks (20 bytes)
-i32vec2 repack(uint ib, uint iqs) {
-#ifdef DATA_A_Q4_0
-    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
-                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
-    const uint32_t vui = pack32(quants);
-    return i32vec2( vui        & 0x0F0F0F0F,
-                   (vui >>  4) & 0x0F0F0F0F);
-#else // DATA_A_Q4_1
-    const uint32_t vui = data_a_packed32[ib].qs[iqs];
-    return i32vec2( vui        & 0x0F0F0F0F,
-                   (vui >>  4) & 0x0F0F0F0F);
-#endif
-}
-
-#ifdef DATA_A_Q4_0
-ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
-}
-#else // DATA_A_Q4_1
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
-}
-#endif
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
 #ifdef DATA_A_Q4_0
     buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],

@@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
     }

-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+#ifdef DATA_A_Q4_0
+    return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
+#else // DATA_A_Q4_1
+    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
+#endif
 }
-#endif // MMQ_SHMEM
+#endif

-#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
+#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
 // 2-byte loads for Q5_0 blocks (22 bytes)
 // 4-byte loads for Q5_1 blocks (24 bytes)
-i32vec2 repack(uint ib, uint iqs) {
-    const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2    ],
-                                   data_a_packed16[ib].qs[iqs * 2 + 1]);
-    const uint32_t vui = pack32(quants);
-#ifdef DATA_A_Q5_0
-    const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
-#else // DATA_A_Q5_1
-    const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
-#endif
-    const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
-                       | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
-
-    const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
-                       | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
-
-    return i32vec2(v0, v1);
-}
-
-#ifdef DATA_A_Q5_0
-ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
-}
-#else // DATA_A_Q5_1
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
-}
-#endif
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
 #ifdef DATA_A_Q5_0
     buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],

@@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
     }

-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+#ifdef DATA_A_Q5_0
+    return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
+#else // DATA_A_Q5_1
+    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
+#endif
 }
-#endif // MMQ_SHMEM
-
 #endif

 #if defined(DATA_A_Q8_0)
 // 2-byte loads for Q8_0 blocks (34 bytes)
-int32_t repack(uint ib, uint iqs) {
-    return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2    ],
-                          data_a_packed16[ib].qs[iqs * 2 + 1]));
-}
-
-ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(float(q_sum) * da * dsb.x);
-}
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
                                            data_a_packed16[ib].qs[iqs * 2 + 1]));

@@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         q_sum += dotPacked4x8EXT(qs_a, qs_b);
     }

-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
+    return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
 }
-#endif // MMQ_SHMEM
-
 #endif

 #if defined(DATA_A_MXFP4)
 // 1-byte loads for mxfp4 blocks (17 bytes)
-i32vec2 repack(uint ib, uint iqs) {
-    const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
-                                          data_a[ib].qs[iqs * 4 + 1],
-                                          data_a[ib].qs[iqs * 4 + 2],
-                                          data_a[ib].qs[iqs * 4 + 3]));
-
-    return i32vec2( quants       & 0x0F0F0F0F,
-                   (quants >> 4) & 0x0F0F0F0F);
-}
-
-ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(da * dsb.x * float(q_sum));
-}
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4    ],
                                       data_a[ib].qs[iqs * 4 + 1],

@@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
     }

-    return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
+    return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
 }
-#endif // MMQ_SHMEM
-
 #endif

 // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
 // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
 #if defined(DATA_A_Q2_K)
 // 4-byte loads for Q2_K blocks (84 bytes)
-int32_t repack(uint ib, uint iqs) {
-    const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs;
-
-    const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
-    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
-
-    return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
-}
-
-uint8_t get_scale(uint ib, uint iqs) {
-    const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs;
-
-    return data_a[ib_k].scales[iqs_k / 4];
-}
-
-ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
-}
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
     const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

@@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
     }

-    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+    return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
 }
-#endif // MMQ_SHMEM
-
 #endif

 #if defined(DATA_A_Q3_K)
 // 2-byte loads for Q3_K blocks (110 bytes)
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
     const uint hm_idx = iqs * QUANT_R_MMQ;

@@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
     }
     result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

-    return ACC_TYPE(cache_b.ds.x * result);
+    return ACC_TYPE(float(cache_b.ds.x) * result);
 }
-#endif // MMQ_SHMEM
-
 #endif

 #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
 // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
-ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
-    return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
-}
-
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
     const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;

@@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
                                    (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
 #endif
-
     if (iqs == 0) {
         // Scale index
         const uint is = iqs_k / 8;

@@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
         q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
     }

-    return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
-}
-#endif // MMQ_SHMEM
-
-#endif
-
-#ifdef MMQ_SHMEM
-void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
-    if (is_in_bounds) {
-        const uint ib_outer = ib / 4;
-        const uint ib_inner = ib % 4;
-
-        if (iqs == 0) {
-            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
-        }
-
-        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-        buf_b[buf_ib].qs[iqs * 4    ] = values.x;
-        buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
-        buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
-        buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
-    } else {
-        if (iqs == 0) {
-            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
-        }
-
-        buf_b[buf_ib].qs[iqs * 4    ] = 0;
-        buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
-        buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
-        buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
-    }
-}
-
-void block_b_to_registers(const uint ib) {
-    cache_b.ds = buf_b[ib].ds;
-    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
-        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
-    }
+    return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
 }
 #endif

 #if defined(DATA_A_Q6_K)
 // 2-byte loads for Q6_K blocks (210 bytes)
-#ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
     const uint iqs_k = (ib % 8) * 8 + iqs;

@@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
     }
     result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);

-    return ACC_TYPE(cache_b.ds.x * result);
+    return ACC_TYPE(float(cache_b.ds.x) * result);
 }
-#endif // MMQ_SHMEM
-
 #endif

-#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
-FLOAT_TYPE get_d(uint ib) {
-    return FLOAT_TYPE(data_a[ib].d);
-}
-#endif
-
-#if defined(DATA_A_MXFP4)
-FLOAT_TYPE get_d(uint ib) {
-    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
-}
-#endif
-
-#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
-FLOAT_TYPE_VEC2 get_dm(uint ib) {
-    return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
-}
-#endif
-
-#if defined(DATA_A_Q2_K)
-FLOAT_TYPE_VEC2 get_dm(uint ib) {
-    const uint ib_k = ib / 8;
-    return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
-}
-#endif
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
+    if (is_in_bounds) {
+        const uint ib_outer = ib / 4;
+        const uint ib_inner = ib % 4;
+
+        if (iqs == 0) {
+            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+        }
+
+        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+        buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+        buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+        buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+        buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+    } else {
+        if (iqs == 0) {
+            buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
+        }
+
+        buf_b[buf_ib].qs[iqs * 4    ] = 0;
+        buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
+        buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
+        buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
+    }
+}
+
+void block_b_to_registers(const uint ib) {
+    cache_b.ds = buf_b[ib].ds;
+    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
 }
 #endif

@@ -679,14 +679,20 @@ void process_shaders() {
         string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
         string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));

-        string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
+        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
+        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
+        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));

         // mul mat vec with integer dot product
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (is_legacy_quant(tname)) {
+        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
+
+            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
+            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
+            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
         }
 #endif
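Note: each shader now comes in three flavors — plain, USE_SUBGROUP_ADD, and USE_SUBGROUP_ADD_NO_SHMEM — and the generated arr_dmmv_* arrays below hold them in that order. A sketch of how a backend might index them (the helper logic here is a hypothetical stand-in for the device checks in ggml-vulkan.cpp):

    #include <cstdint>

    // Index into the three-entry arr_dmmv_* arrays emitted below.
    enum dmmv_variant : uint32_t {
        DMMV_PLAIN             = 0, // mul_mat_vec_id_*_q8_1_f32
        DMMV_SUBGROUP          = 1, // ..._subgroup (USE_SUBGROUP_ADD)
        DMMV_SUBGROUP_NO_SHMEM = 2, // ..._subgroup_no_shmem
    };

    // Hypothetical selection: prefer subgroup reductions when supported,
    // and skip shared memory when the device favors that path.
    static uint32_t pick_variant(bool subgroup_add, bool prefer_no_shmem) {
        if (!subgroup_add) {
            return DMMV_PLAIN;
        }
        return prefer_no_shmem ? DMMV_SUBGROUP_NO_SHMEM : DMMV_SUBGROUP;
    }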

@@ -1100,7 +1106,7 @@ void write_output_files() {

     for (const std::string& btype : btypes) {
         for (const auto& tname : type_names) {
-            if (btype == "q8_1" && !is_legacy_quant(tname)) {
+            if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
                 continue;
             }
             hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

@@ -1109,6 +1115,16 @@ void write_output_files() {
                 src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
                 src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
             }
+
+            if (btype == "f16") {
+                continue;
+            }
+            hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
+            hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
+            if (basename(input_filepath) == "mul_mat_vec.comp") {
+                src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
+                src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
+            }
         }
     }
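Note: for one concrete case, the emission above produces declarations and definitions along these lines for tname == "q4_0" and btype == "f32" (illustrative excerpt of the generated output):

    // generated header (illustrative):
    extern const void *   arr_dmmv_id_q4_0_f32_f32_data[3];
    extern const uint64_t arr_dmmv_id_q4_0_f32_f32_len[3];

    // generated source (illustrative):
    const void * arr_dmmv_id_q4_0_f32_f32_data[3] = {
        mul_mat_vec_id_q4_0_f32_f32_data,
        mul_mat_vec_id_q4_0_f32_f32_subgroup_data,
        mul_mat_vec_id_q4_0_f32_f32_subgroup_no_shmem_data,
    };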