Remove use_wce flag

This commit is contained in:
Ed Addario 2026-03-12 19:55:56 +00:00
parent 0ccf5e5f21
commit 9bb8e17e04
No known key found for this signature in database
GPG Key ID: E7875815A3230993
3 changed files with 15 additions and 32 deletions

View File

@ -403,7 +403,6 @@ extern "C" {
bool save_state; // keep bpw state file
void * state_file; // pointer to bpw state file
float importance_pct; // identify up to pct% of tensors as important
bool use_wce; // optimize for WCE instead of MSE
} llama_model_quantize_params;
typedef struct llama_logit_bias {

View File

@ -8,6 +8,7 @@
#include <cinttypes>
#include <cmath>
#include <csignal>
#include <cstring>
#include <fstream>
#include <mutex>
#include <numeric>
@ -721,18 +722,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
constexpr double EPSILON = 1e-12;
constexpr double INFINITE = std::numeric_limits<double>::infinity();
constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1
constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1
constexpr uint64_t STATE_MAGIC = 0x4250572d5631; // "BPW-V1"
constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d;
constexpr float penalty = 5.0f;
const char * func = __func__;
const bool wce = params->use_wce;
const bool valid_wce = wce && activations_data && statistics_data != nullptr;
const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC;
if (wce && !valid_wce) {
LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func);
}
// Tensor size in bytes for a given type
auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t {
@ -834,15 +827,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
name.erase(0, name.find_last_of('/') + 1);
std::replace(name.begin(), name.end(), ' ', '_');
name.empty() ? checkpoint_file = ml.arch_name : checkpoint_file = name;
checkpoint_file += "-" + std::string(hex) + (valid_wce ? "-wce" : "-mse") + ".bpw_state";
checkpoint_file += "-" + std::string(hex) + ".bpw_state";
if (params->state_file) {
const auto * filename = static_cast<const char*>(params->state_file);
if (qs.params->state_file) {
const auto * filename = static_cast<const char*>(qs.params->state_file);
bool is_valid = false;
if (std::ifstream(filename, std::ios::binary).good()) {
is_valid = true;
} else if (params->save_state) {
} else if (qs.params->save_state) {
std::ofstream ofs(filename, std::ios::binary | std::ios::app);
if (ofs.is_open()) {
is_valid = true;
@ -865,7 +858,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const std::string tmp = checkpoint_file + ".tmp";
std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc);
if (!ofs) { return; }
ofs.write((const char *)& file_magic, sizeof(file_magic));
ofs.write((const char *)& STATE_MAGIC, sizeof(STATE_MAGIC));
ofs.write((const char *)& model_id, sizeof(model_id));
const uint64_t n = all_tensors.size();
ofs.write((const char *)& n, sizeof(n));
@ -904,7 +897,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
std::ifstream ifs(checkpoint_file, std::ios::binary);
if (!ifs) { return {}; }
uint32_t magic = 0;
uint64_t magic = 0;
uint64_t id = 0;
ifs.read((char *)& magic, sizeof(magic));
ifs.read((char *)& id, sizeof(id));
@ -913,9 +906,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
return {};
}
if (magic != file_magic) {
LLAMA_LOG_WARN("%s: bpw state file mismatch (expected %s, got %s), ignoring\n",
func, file_magic == MSE_MAGIC ? "MSE" : "WCE", magic == MSE_MAGIC ? "MSE" : "WCE");
if (magic != STATE_MAGIC) {
LLAMA_LOG_WARN("%s: invalid state file, ignoring\n", func);
return {};
}
@ -950,9 +942,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
ifs.read((char *)& b, sizeof(b));
cd.bytes = (size_t)b;
ifs.read((char *)& cd.error, sizeof(cd.error));
// Populate mse/wce for consistency, though optimization relies on s.error
if (valid_wce) { cd.wce = cd.error; }
else { cd.mse = cd.error; }
}
out.emplace(std::move(name), std::move(si));
@ -2126,11 +2115,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (params->importance_pct != 0.0f) {
LLAMA_LOG_INFO("%s: marking up to %.2f%% of tensors as important\n", __func__, params->importance_pct);
}
if (params->use_wce) {
LLAMA_LOG_INFO("%s: using experimental Weighted Cosine Error (WCE) optimization\n", __func__);
} else {
LLAMA_LOG_INFO("%s: using default Weighted Mean Squared Error (MSE) optimization\n", __func__);
}
if (params->target_size >= 0) {
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n", __func__, (double)params->target_size / 1024.0 / 1024.0);
} else {
@ -2138,7 +2122,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
// get quantization type overrides targeting a given bits per weight budget
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread);
bpw_overrides = target_bpw_type(ml, qs, tensors, mapped, values_data, activations_data, statistics_data, nthread);
for (size_t i = 0; i < tensors.size(); ++i) {
const std::string name = ggml_get_name(tensors[i]->tensor);
auto it = bpw_overrides.find(name);
@ -2363,8 +2347,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
/*.target_size =*/ -1,
/*.save_state =*/ false,
/*.state_file =*/ nullptr,
/*.importance_pct =*/ 0.0f,
/*.use_wce =*/ false
/*.importance_pct =*/ 0.0f
};
return result;

View File

@ -3,8 +3,11 @@
#include "gguf.h"
#include <algorithm>
#include <cctype>
#include <clocale>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <map>
@ -710,8 +713,6 @@ int main(int argc, char ** argv) {
if (arg_idx == argc-1 || !parse_target_size(argv[++arg_idx], target_size)) {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--use-wce") == 0) {
params.use_wce = true;
} else if (strcmp(argv[arg_idx], "--importance-pct") == 0) {
if (arg_idx == argc-1 || !parse_importance_pct(argv[++arg_idx], importance_pct)) {
usage(argv[0]);