diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1bee3e187c..b353d04142 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3079,6 +3079,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec case GGML_TYPE_MXFP4: lut_size = 4*16; break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; default: break; } @@ -3558,6 +3562,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3588,6 +3593,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3651,6 +3657,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3674,6 +3681,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3708,6 +3716,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3773,6 +3782,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3819,6 +3829,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3864,6 +3875,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3939,6 +3951,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3983,6 +3996,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4010,6 +4024,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -4108,6 +4123,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); @@ -4133,6 +4149,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4184,6 +4201,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4239,6 +4257,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4265,6 +4284,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4291,6 +4311,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -6089,6 +6110,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6161,6 +6183,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6227,6 +6250,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6318,6 +6342,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6387,6 +6412,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -15373,6 +15399,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -15488,6 +15515,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df509525..6a69214747 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index ede1275cfc..88d07d2dfd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 03035f2812..c582aba87d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} +#endif + #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 #elif defined(DATA_A_Q4_0) @@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 0000000000..689089160b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 219bd60803..6e4a29d2fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_mxfp4[vui2 & 0xF] * d); buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 1fb592fb84..4bcd97756f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1713,6 +1713,22 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 607eef7d0d..b232927658 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -66,6 +66,7 @@ const std::vector type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") {