Remove WCE flag
This commit is contained in:
parent
0ccf5e5f21
commit
9bb8e17e04
|
|
@ -403,7 +403,6 @@ extern "C" {
|
|||
bool save_state; // keep bpw state file
|
||||
void * state_file; // pointer to bpw state file
|
||||
float importance_pct; // identify up to pct% of tensors as important
|
||||
bool use_wce; // optimize for WCE instead of MSE
|
||||
} llama_model_quantize_params;
|
||||
|
||||
typedef struct llama_logit_bias {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
#include <csignal>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
|
|
@ -721,18 +722,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
constexpr double EPSILON = 1e-12;
|
||||
constexpr double INFINITE = std::numeric_limits<double>::infinity();
|
||||
constexpr uint32_t MSE_MAGIC = 0x4d534531; // MSE1
|
||||
constexpr uint32_t WCE_MAGIC = 0x57434531; // WCE1
|
||||
constexpr uint64_t STATE_MAGIC = 0x4250572d5631; // "BPW-V1"
|
||||
constexpr uint64_t HASH_MAGIC = 0xeabada55cafed00d;
|
||||
constexpr float penalty = 5.0f;
|
||||
const char * func = __func__;
|
||||
const bool wce = params->use_wce;
|
||||
const bool valid_wce = wce && activations_data && statistics_data != nullptr;
|
||||
const uint32_t file_magic = valid_wce ? WCE_MAGIC : MSE_MAGIC;
|
||||
|
||||
if (wce && !valid_wce) {
|
||||
LLAMA_LOG_WARN("%s: WCE optimization requested but no activation or statistics data provided; using default MSE optimization.\n", func);
|
||||
}
|
||||
|
||||
// Tensor size in bytes for a given type
|
||||
auto tensor_bytes = [](const ggml_tensor * gt, const ggml_type gq) -> size_t {
|
||||
|
|
@ -834,15 +827,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
name.erase(0, name.find_last_of('/') + 1);
|
||||
std::replace(name.begin(), name.end(), ' ', '_');
|
||||
name.empty() ? checkpoint_file = ml.arch_name : checkpoint_file = name;
|
||||
checkpoint_file += "-" + std::string(hex) + (valid_wce ? "-wce" : "-mse") + ".bpw_state";
|
||||
checkpoint_file += "-" + std::string(hex) + ".bpw_state";
|
||||
|
||||
if (params->state_file) {
|
||||
const auto * filename = static_cast<const char*>(params->state_file);
|
||||
if (qs.params->state_file) {
|
||||
const auto * filename = static_cast<const char*>(qs.params->state_file);
|
||||
bool is_valid = false;
|
||||
|
||||
if (std::ifstream(filename, std::ios::binary).good()) {
|
||||
is_valid = true;
|
||||
} else if (params->save_state) {
|
||||
} else if (qs.params->save_state) {
|
||||
std::ofstream ofs(filename, std::ios::binary | std::ios::app);
|
||||
if (ofs.is_open()) {
|
||||
is_valid = true;
|
||||
|
|
@ -865,7 +858,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const std::string tmp = checkpoint_file + ".tmp";
|
||||
std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc);
|
||||
if (!ofs) { return; }
|
||||
ofs.write((const char *)& file_magic, sizeof(file_magic));
|
||||
ofs.write((const char *)& STATE_MAGIC, sizeof(STATE_MAGIC));
|
||||
ofs.write((const char *)& model_id, sizeof(model_id));
|
||||
const uint64_t n = all_tensors.size();
|
||||
ofs.write((const char *)& n, sizeof(n));
|
||||
|
|
@ -904,7 +897,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
std::ifstream ifs(checkpoint_file, std::ios::binary);
|
||||
if (!ifs) { return {}; }
|
||||
|
||||
uint32_t magic = 0;
|
||||
uint64_t magic = 0;
|
||||
uint64_t id = 0;
|
||||
ifs.read((char *)& magic, sizeof(magic));
|
||||
ifs.read((char *)& id, sizeof(id));
|
||||
|
|
@ -913,9 +906,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
return {};
|
||||
}
|
||||
|
||||
if (magic != file_magic) {
|
||||
LLAMA_LOG_WARN("%s: bpw state file mismatch (expected %s, got %s), ignoring\n",
|
||||
func, file_magic == MSE_MAGIC ? "MSE" : "WCE", magic == MSE_MAGIC ? "MSE" : "WCE");
|
||||
if (magic != STATE_MAGIC) {
|
||||
LLAMA_LOG_WARN("%s: invalid state file, ignoring\n", func);
|
||||
return {};
|
||||
}
|
||||
|
||||
|
|
@ -950,9 +942,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
ifs.read((char *)& b, sizeof(b));
|
||||
cd.bytes = (size_t)b;
|
||||
ifs.read((char *)& cd.error, sizeof(cd.error));
|
||||
// Populate mse/wce for consistency, though optimization relies on s.error
|
||||
if (valid_wce) { cd.wce = cd.error; }
|
||||
else { cd.mse = cd.error; }
|
||||
}
|
||||
|
||||
out.emplace(std::move(name), std::move(si));
|
||||
|
|
@ -2126,11 +2115,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
if (params->importance_pct != 0.0f) {
|
||||
LLAMA_LOG_INFO("%s: marking up to %.2f%% of tensors as important\n", __func__, params->importance_pct);
|
||||
}
|
||||
if (params->use_wce) {
|
||||
LLAMA_LOG_INFO("%s: using experimental Weighted Cosine Error (WCE) optimization\n", __func__);
|
||||
} else {
|
||||
LLAMA_LOG_INFO("%s: using default Weighted Mean Squared Error (MSE) optimization\n", __func__);
|
||||
}
|
||||
if (params->target_size >= 0) {
|
||||
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve file size %.2f MiB\n", __func__, (double)params->target_size / 1024.0 / 1024.0);
|
||||
} else {
|
||||
|
|
@ -2138,7 +2122,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
|||
}
|
||||
|
||||
// get quantization type overrides targeting a given bits per weight budget
|
||||
bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread);
|
||||
bpw_overrides = target_bpw_type(ml, qs, tensors, mapped, values_data, activations_data, statistics_data, nthread);
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const std::string name = ggml_get_name(tensors[i]->tensor);
|
||||
auto it = bpw_overrides.find(name);
|
||||
|
|
@ -2363,8 +2347,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
|
|||
/*.target_size =*/ -1,
|
||||
/*.save_state =*/ false,
|
||||
/*.state_file =*/ nullptr,
|
||||
/*.importance_pct =*/ 0.0f,
|
||||
/*.use_wce =*/ false
|
||||
/*.importance_pct =*/ 0.0f
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -3,8 +3,11 @@
|
|||
#include "gguf.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <clocale>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
|
|
@ -710,8 +713,6 @@ int main(int argc, char ** argv) {
|
|||
if (arg_idx == argc-1 || !parse_target_size(argv[++arg_idx], target_size)) {
|
||||
usage(argv[0]);
|
||||
}
|
||||
} else if (strcmp(argv[arg_idx], "--use-wce") == 0) {
|
||||
params.use_wce = true;
|
||||
} else if (strcmp(argv[arg_idx], "--importance-pct") == 0) {
|
||||
if (arg_idx == argc-1 || !parse_importance_pct(argv[++arg_idx], importance_pct)) {
|
||||
usage(argv[0]);
|
||||
|
|
|
|||
Loading…
Reference in New Issue