diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 69e03179b3..89cf0fbf80 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -497,6 +497,41 @@ static bool parse_target_bpw(const char * data, float & target_bpw) {
     return true;
 }
 
+// Map a target bpw value to the name of a quantization type.
+// Picks the first table entry whose bpw is not less than the requested value;
+// requests above the largest entry are clamped to the last entry instead of
+// dereferencing a past-the-end position.
+static const char * get_ftype(const float bpw) {
+    struct bpw_ftype {
+        float        max_bpw;
+        const char * name;
+    };
+
+    // must stay sorted by ascending bpw
+    static constexpr bpw_ftype quant_bpw[] = {
+        {1.5625f, "IQ1_S"},
+        {1.7500f, "IQ1_M"},
+        {2.0625f, "IQ2_XXS"},
+        {2.6250f, "Q2_K"},
+        {3.0625f, "IQ3_XXS"},
+        {3.4375f, "Q3_K"},
+        {4.2500f, "IQ4_XS"},
+        {4.5000f, "Q4_K"},
+        {5.5000f, "Q5_K"},
+        {6.5625f, "Q6_K"},
+        {8.5000f, "Q8_0"},
+    };
+
+    for (const auto & entry : quant_bpw) {
+        if (!(entry.max_bpw < bpw)) {
+            return entry.name;
+        }
+    }
+
+    // bpw exceeds every entry: fall back to the largest type in the table
+    return quant_bpw[sizeof(quant_bpw) / sizeof(quant_bpw[0]) - 1].name;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -655,6 +690,7 @@ int main(int argc, char ** argv) {
 
     std::string ftype_str;
     std::string suffix = ".gguf";
+    std::vector tmp_argv(argv, argv + argc);
     if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
         std::string fpath;
         const size_t pos = fname_inp.find_last_of("/\\");
@@ -678,7 +714,16 @@ int main(int argc, char ** argv) {
         }
         arg_idx++;
 
-        if (argc <= arg_idx) {
+        // select quantization type if target_bpw is set unless user specifies type and threads
+        if (argc - arg_idx <= 1 && params.target_bpw != -1.0f) {
+            auto * ftype = const_cast<char *>(get_ftype(params.target_bpw));
+            // place the derived type at arg_idx: appended when no arguments remain,
+            // otherwise inserted ahead of the single remaining argument
+            tmp_argv.insert(argc == arg_idx ? tmp_argv.end() : tmp_argv.end() - 1, ftype);
+            tmp_argv.push_back(nullptr); // keep the patched argv null-terminated
+            argv = const_cast<char **>(tmp_argv.data());
+            argc++;
+        } else if (argc <= arg_idx) {
             fprintf(stderr, "%s: missing ftype\n", __func__);
             return 1;
         }