diff --git a/include/llama.h b/include/llama.h
index f745e2110b..ce04011e19 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -367,6 +367,7 @@ extern "C" {
         void * prune_layers;   // pointer to vector containing layer indices to prune
         float target_bpw;      // target bits per weight (bpw)
         bool keep_bpw_state;   // keep bpw state file
+        void * bpw_state;      // pointer to bpw state file name
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 38d20e3d0f..1dee52d58d 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -762,7 +762,23 @@ static std::unordered_map target_bpw_type(
     char hex[17];
     const uint64_t model_id = metadata_id(ml.meta.get());
     std::snprintf(hex, sizeof(hex), "%016" PRIx64, (uint64_t)model_id);
-    const std::string checkpoint_file = ml.arch_name + "-" + std::string(hex) + ".bpw_state";
+    std::string checkpoint_file = ml.arch_name + "-" + std::string(hex) + ".bpw_state";
+    if (params->keep_bpw_state && params->bpw_state) {
+        const auto * filename = static_cast<const char *>(params->bpw_state);
+        std::ifstream ifs(filename, std::ios::binary);
+        if (ifs.good()) {
+            checkpoint_file = std::string(filename);
+        } else {
+            std::ofstream ofs(filename, std::ios::binary | std::ios::app);
+            if (ofs.is_open()) {
+                checkpoint_file = std::string(filename);
+                ofs.close();
+                std::remove(checkpoint_file.c_str());
+            } else {
+                LLAMA_LOG_WARN("%s: %s is not a valid file name. Using %s instead\n", func, filename, checkpoint_file.c_str());
+            }
+        }
+    }
 
     auto save_bpw_state = [&](const std::vector & all_vec) {
         const std::string tmp = checkpoint_file + ".tmp";
@@ -2306,7 +2322,8 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.tensor_type     =*/ nullptr,
         /*.prune_layers    =*/ nullptr,
         /*.target_bpw      =*/ -1.0f,
-        /*.keep_bpw_state  =*/ false
+        /*.keep_bpw_state  =*/ false,
+        /*.bpw_state       =*/ nullptr
     };
 
     return result;
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index e67649beb9..945acbe288 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -117,8 +117,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights]\n", executable);
-    printf("       [--target-bpw n] [--keep-bpw-state] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--target-bpw n]\n", executable);
+    printf("       [--bpw-state filename] [--keep-bpw-state] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
     printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
     printf("  --allow-requantize: allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
@@ -128,13 +128,14 @@ static void usage(const char * executable) {
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
    printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
-    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
+    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. Example: --tensor-type attn_q=q8_0\n");
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
     printf("      Advanced option to remove all tensors from the given layers\n");
     printf("  --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 8.0\n");
     printf("      Advanced option to automatically select quantization types to achieve a total bits per weight (bpw) target\n");
-    printf("  --keep-bpw-state: preserve the bpw computations in a state file\n");
+    printf("  --keep-bpw-state: save the bpw computations to <arch>-<model id>.bpw_state\n");
+    printf("  --bpw-state filename: file name to use for the bpw state instead of the default\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -562,6 +563,12 @@ int main(int argc, char ** argv) {
             }
         } else if (strcmp(argv[arg_idx], "--keep-bpw-state") == 0) {
             params.keep_bpw_state = true;
+        } else if (strcmp(argv[arg_idx], "--bpw-state") == 0) {
+            if (arg_idx < argc-1) {
+                params.bpw_state = argv[++arg_idx];
+            } else {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
             if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
                 usage(argv[0]);
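Not part of the patch, but for context: a minimal C++ sketch of how a library caller (rather than the quantize tool) could drive the new bpw_state field through the existing llama_model_quantize() API. The file names and the 6.5 bpw target are hypothetical example values.

    #include "llama.h"

    int main() {
        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.target_bpw     = 6.5f;       // request bpw-targeted type selection
        params.keep_bpw_state = true;       // keep the bpw state file after the run
        static char state_file[] = "my.bpw_state";
        params.bpw_state      = state_file; // custom state file instead of the default <arch>-<model id>.bpw_state
        // llama_model_quantize() returns 0 on success
        return llama_model_quantize("model-f32.gguf", "model-quant.gguf", &params) == 0 ? 0 : 1;
    }

With this patch applied, the rough command-line equivalent is: llama-quantize --target-bpw 6.5 --keep-bpw-state --bpw-state my.bpw_state followed by the positional arguments shown in the usage text (model-f32.gguf [model-quant.gguf] type [nthreads]).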