From 9bb8e17e04fe7c8940fc4ba7351e8b6d8764a8be Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Thu, 12 Mar 2026 19:55:56 +0000 Subject: [PATCH] Remove wce flag --- include/llama.h | 1 - src/llama-quant.cpp | 41 +++++++++++-------------------------- tools/quantize/quantize.cpp | 5 +++-- 3 files changed, 15 insertions(+), 32 deletions(-) diff --git a/include/llama.h b/include/llama.h index d811ae9814..947a22e0b0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -403,7 +403,6 @@ extern "C" { bool save_state; // keep bpw state file void * state_file; // pointer to bpw state file float importance_pct; // identify up to pct% of tensors as important - bool use_wce; // optimize for WCE instead of MSE } llama_model_quantize_params; typedef struct llama_logit_bias { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 0e76bebc9e..fd0c5fe636 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -721,18 +722,10 @@ static std::unordered_map target_bpw_type( constexpr double EPSILON = 1e-12; constexpr double INFINITE = std::numeric_limits::infinity(); - constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1 - constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1 + constexpr uint64_t STATE_MAGIC = 0x4250572d5631; // "BPW-V1" constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d; constexpr float penalty = 5.0f; const char * func = __func__; - const bool wce = params->use_wce; - const bool valid_wce = wce && activations_data && statistics_data != nullptr; - const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC; - - if (wce && !valid_wce) { - LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func); - } // Tensor size in bytes for a given type auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t { @@ -834,15 +827,15 @@ static std::unordered_map target_bpw_type( name.erase(0, name.find_last_of('/') + 1); std::replace(name.begin(), name.end(), ' ', '_'); name.empty() ? checkpoint_file = ml.arch_name : checkpoint_file = name; - checkpoint_file += "-" + std::string(hex) + (valid_wce ? "-wce" : "-mse") + ".bpw_state"; + checkpoint_file += "-" + std::string(hex) + ".bpw_state"; - if (params->state_file) { - const auto * filename = static_cast(params->state_file); + if (qs.params->state_file) { + const auto * filename = static_cast(qs.params->state_file); bool is_valid = false; if (std::ifstream(filename, std::ios::binary).good()) { is_valid = true; - } else if (params->save_state) { + } else if (qs.params->save_state) { std::ofstream ofs(filename, std::ios::binary | std::ios::app); if (ofs.is_open()) { is_valid = true; @@ -865,7 +858,7 @@ static std::unordered_map target_bpw_type( const std::string tmp = checkpoint_file + ".tmp"; std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc); if (!ofs) { return; } - ofs.write((const char *)& file_magic, sizeof(file_magic)); + ofs.write((const char *)& STATE_MAGIC, sizeof(STATE_MAGIC)); ofs.write((const char *)& model_id, sizeof(model_id)); const uint64_t n = all_tensors.size(); ofs.write((const char *)& n, sizeof(n)); @@ -904,7 +897,7 @@ static std::unordered_map target_bpw_type( std::ifstream ifs(checkpoint_file, std::ios::binary); if (!ifs) { return {}; } - uint32_t magic = 0; + uint64_t magic = 0; uint64_t id = 0; ifs.read((char *)& magic, sizeof(magic)); ifs.read((char *)& id, sizeof(id)); @@ -913,9 +906,8 @@ static std::unordered_map target_bpw_type( return {}; } - if (magic != file_magic) { - LLAMA_LOG_WARN("%s: bpw state file mismatch (expected %s, got %s), ignoring\n", - func, file_magic == MSE_MAGIC ? "MSE" : "WCE", magic == MSE_MAGIC ? "MSE" : "WCE"); + if (magic != STATE_MAGIC) { + LLAMA_LOG_WARN("%s: invalid state file, ignoring\n", func); return {}; } @@ -950,9 +942,6 @@ static std::unordered_map target_bpw_type( ifs.read((char *)& b, sizeof(b)); cd.bytes = (size_t)b; ifs.read((char *)& cd.error, sizeof(cd.error)); - // Populate mse/wce for consistency, though optimization relies on s.error - if (valid_wce) { cd.wce = cd.error; } - else { cd.mse = cd.error; } } out.emplace(std::move(name), std::move(si)); @@ -2126,11 +2115,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->importance_pct != 0.0f) { LLAMA_LOG_INFO("%s: marking up to %.2f%% of tensors as important\n", __func__, params->importance_pct); } - if (params->use_wce) { - LLAMA_LOG_INFO("%s: using experimental Weighted Cosine Error (WCE) optimization\n", __func__); - } else { - LLAMA_LOG_INFO("%s: using default Weighted Mean Squared Error (MSE) optimization\n", __func__); - } if (params->target_size >= 0) { LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n", __func__, (double)params->target_size / 1024.0 / 1024.0); } else { @@ -2138,7 +2122,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } // get quantization type overrides targeting a given bits per weight budget - bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread); + bpw_overrides = target_bpw_type(ml, qs, tensors, mapped, values_data, activations_data, statistics_data, nthread); for (size_t i = 0; i < tensors.size(); ++i) { const std::string name = ggml_get_name(tensors[i]->tensor); auto it = bpw_overrides.find(name); @@ -2363,8 +2347,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.target_size =*/ -1, /*.save_state =*/ false, /*.state_file =*/ nullptr, - /*.importance_pct =*/ 0.0f, - /*.use_wce =*/ false + /*.importance_pct =*/ 0.0f }; return result; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index f06a450ed8..39e976f8e5 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -3,8 +3,11 @@ #include "gguf.h" #include +#include #include +#include #include +#include #include #include #include @@ -710,8 +713,6 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !parse_target_size(argv[++arg_idx], target_size)) { usage(argv[0]); } - } else if (strcmp(argv[arg_idx], "--use-wce") == 0) { - params.use_wce = true; } else if (strcmp(argv[arg_idx], "--importance-pct") == 0) { if (arg_idx == argc-1 || !parse_importance_pct(argv[++arg_idx], importance_pct)) { usage(argv[0]);