diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 33b7f7e584..3d4785c1a3 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -765,10 +765,18 @@ static std::unordered_map target_bpw_type( return djb2_hash(buf.data(), buf.size()); }; + std::string gen_name; + std::string checkpoint_file; char hex[17]; const uint64_t model_id = metadata_id(ml.meta.get()); + std::snprintf(hex, sizeof(hex), "%016" PRIx64, (uint64_t)model_id); - std::string checkpoint_file = ml.arch_name + "-" + std::string(hex) + ".bpw_state"; + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + std::replace(gen_name.begin(), gen_name.end(), ' ', '_'); + + gen_name.empty() ? checkpoint_file = ml.arch_name : checkpoint_file = gen_name; + checkpoint_file += "-" + std::string(hex) + ".bpw_state"; + if (params->keep_bpw_state && params->bpw_state) { const auto * filename = static_cast(params->bpw_state); std::ifstream ifs(filename, std::ios::binary);