Minor refactoring

This commit is contained in:
Ed Addario 2025-10-28 15:22:32 +00:00
parent 5303212324
commit f8863b9a80
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 23 additions and 25 deletions

View File

@ -694,6 +694,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
constexpr double epsilon = 1e-12;
constexpr double infinity = std::numeric_limits<double>::infinity();
constexpr uint32_t file_magic = 0x42505731; // BPW1
constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d;
const char * func = __func__;
auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
@ -731,7 +732,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
if (is_compatible(t, typ)) { return typ; }
ggml_type fb = fallback_type(typ);
const ggml_type fb = fallback_type(typ);
return is_compatible(t, fb) ? fb : GGML_TYPE_F16;
};
@ -754,7 +755,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
for (size_t i = 0; i < n; ++i) {
h = (h << 5) + h + data[i];
}
return h ? h : 0xeabada55cafed00d;
return h ? h : arbitrary_magic;
};
auto metadata_id = [&](const gguf_context * ctx) -> uint64_t {
@ -795,7 +796,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
ofs.write((const char *)&n, sizeof(n));
for (const auto & ti : all_vec) {
const std::string name = ggml_get_name(ti.w->tensor);
const uint32_t len = (uint32_t)name.size();
const auto len = (uint32_t)name.size();
ofs.write((const char *)&len, sizeof(len));
ofs.write(name.data(), len);
@ -835,13 +836,14 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (magic != file_magic) {
LLAMA_LOG_WARN("%s: invalid resume file, ignoring: %s\n", func, checkpoint_file.c_str());
return out;
} else if (id != model_id) {
}
if (id != model_id) {
LLAMA_LOG_WARN("%s: model ID mismatch, ignoring: %s\n", func, checkpoint_file.c_str());
return out;
} else {
LLAMA_LOG_INFO("%s: state file found, resuming tensor quantization\n", func);
}
LLAMA_LOG_INFO("%s: state file found, resuming tensor quantization\n", func);
uint64_t n = 0;
ifs.read((char *)&n, sizeof(n));
for (uint64_t i = 0; i < n; ++i) {
@ -862,15 +864,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
si.n_elements = (size_t)ne;
si.candidate.resize(cn);
for (size_t j = 0; j < si.candidate.size(); ++j) {
for (auto & s : si.candidate) {
int32_t t = 0;
uint64_t b = 0;
ifs.read((char *)&t, sizeof(t));
si.candidate[j].type = (ggml_type)t;
ifs.read((char *)&si.candidate[j].bpw, sizeof(si.candidate[j].bpw));
s.type = (ggml_type)t;
ifs.read((char *)&s.bpw, sizeof(s.bpw));
ifs.read((char *)&b, sizeof(b));
si.candidate[j].bytes = (size_t)b;
ifs.read((char *)&si.candidate[j].error, sizeof(si.candidate[j].error));
s.bytes = (size_t)b;
ifs.read((char *)&s.error, sizeof(s.error));
}
out.emplace(std::move(name), std::move(si));
@ -886,7 +888,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
std::remove(checkpoint_file.c_str());
}
};
auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
@ -1198,10 +1199,10 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
// Compute rows based on tensor shape and slice count
auto sample_rows = [](const int64_t n, const int64_t rows, const int64_t n2, const bool has_acts) -> int64_t {
const double tensor_budget = has_acts ? 1 * 1024 * 1024 : 0.5 * 1024 * 1024;
const double scale_rows = std::clamp(std::sqrt(std::max(1.0, (double)rows) / 4096.0), 0.5, 2.0); // favour more rows for large nrt
const double scale_rows = std::clamp(std::sqrt(std::max(1.0, (double)rows) / 4096.0), 0.5, 2.0); // favour more rows for large tensors
const double slice_budget = tensor_budget * scale_rows / std::max<int64_t>(1, n2);
const int64_t min_rows = has_acts ? 128 : 64;
const int64_t max_rows = 4096;
constexpr int64_t max_rows = 4096; // row limit to avoid excessive memory use
int64_t total_rows = std::llround(slice_budget / std::max<int64_t>(1, n));
total_rows = std::max<int64_t>(min_rows, std::min<int64_t>(total_rows, std::min<int64_t>(rows, max_rows)));
if (rows <= min_rows * 2) { total_rows = rows; }
@ -1246,7 +1247,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
f32_sample.clear();
std::vector<float> row_buffer(n_per_row);
for (int64_t slice = 0; slice < ne2; ++slice) {
std::mt19937 rng(std::hash<std::string>{}(name) ^ 0xeabada55cafed00d ^ slice);
std::mt19937 rng(std::hash<std::string>{}(name) ^ arbitrary_magic ^ slice);
const int64_t rows_sample_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, rows_sample_per_expert));
const int64_t stride = std::max<int64_t>(1, nrows_total / rows_sample_max);
int64_t offset = 0;
@ -1411,8 +1412,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (c.bytes == 0) { continue; }
const double final_err = bias_needed ? c.error : c.mse;
info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj });
// LLAMA_LOG_INFO("\t%s: %35s \t%10s \t%1.4f bpw \t%10zu bytes \t mse: %1.8e \t err: %1.8e\n",
// func, name.c_str(), ggml_type_name(c.type), c.bpw, c.bytes, c.mse, final_err);
}
if (info.candidate.empty()) {
@ -1445,16 +1444,15 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
if (candidates.size() < 3) { return; } // need at least 3 points to do convex hull
// Convex hull (lower envelope)
auto cross_product = [](const candidate_types & h0, const candidate_types & h1, const candidate_types & p) -> double {
const double dx1 = (double)h1.bytes - (double)h0.bytes;
const double dy1 = h1.error - h0.error;
const double dx2 = (double)p.bytes - (double)h0.bytes;
const double dy2 = p.error - h0.error;
return dx1 * dy2 - dx2 * dy1;
};
std::vector<candidate_types> hull; hull.reserve(candidates.size());
for (const auto & c : candidates) {
auto cross_product = [](const candidate_types & h0, const candidate_types & h1, const candidate_types & p) -> double {
const double dx1 = (double)h1.bytes - (double)h0.bytes;
const double dy1 = h1.error - h0.error;
const double dx2 = (double)p.bytes - (double)h0.bytes;
const double dy2 = p.error - h0.error;
return dx1 * dy2 - dx2 * dy1;
};
while (hull.size() >= 2) {
if (cross_product(hull[hull.size() - 2], hull[hull.size() - 1], c) <= epsilon) {
hull.pop_back();