From e48ca32f19095ba0c47058dc7a703c1bb52977e0 Mon Sep 17 00:00:00 2001
From: Ed Addario
Date: Sun, 5 Oct 2025 20:17:27 +0100
Subject: [PATCH] Add save_bpw_state()

---
Review notes (not part of the commit message):
 - save_bpw_state() now checks the stream after close(), leaves the previous
   checkpoint untouched when writing fails, checks the result of rename(),
   and prints the size_t tensor count with %zu instead of %lu.
 - Line breaks, indentation and the template arguments (candidate_types,
   tensor_info) were restored by hand; confirm them against
   src/llama-quant.cpp before applying.
 - The index line below still carries the original blob ids.

 src/llama-quant.cpp | 71 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 50c8dbf423..3080b0ed71 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -734,6 +734,77 @@ static std::unordered_map target_bpw_type(
         });
     };

+    // Saved state per tensor
+    struct saved_info {
+        std::vector<candidate_types> candidate;
+        int choice = -1;
+        float min_bpw = 0.0f;
+        float max_bpw = 0.0f;
+        size_t n_elements = 0;
+    };
+
+    // Write the bpw search state of every tensor in all_vec to checkpoint_file (best-effort).
+    // Values are dumped raw, in host byte order:
+    //   file_magic, target_bpw (float), bias_mode (uint8), tensor count (uint64), then per tensor:
+    //   name length (uint32) + name bytes, candidate count (uint64), choice, min_bpw, max_bpw,
+    //   n_elements (uint64), then per candidate: type (int32), bpw, bytes (uint64), error.
+    // The data goes to "<checkpoint_file>.tmp" first and only replaces the checkpoint once fully written,
+    // so a failed write leaves an earlier checkpoint untouched.
+    auto save_bpw_state = [&](const std::vector<tensor_info> & all_vec) {
+        const std::string tmp = checkpoint_file + ".tmp";
+        std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc);
+        if (!ofs) { return; } // best-effort
+        const float target_bpw = params->target_bpw;
+        const uint8_t bias_mode = params->no_bias ? 1 : 0;
+        ofs.write((const char *)&file_magic, sizeof(file_magic));
+        ofs.write((const char *)&target_bpw, sizeof(target_bpw));
+        ofs.write((const char *)&bias_mode, sizeof(bias_mode));
+        const uint64_t n = all_vec.size();
+        ofs.write((const char *)&n, sizeof(n));
+        for (const auto & ti : all_vec) {
+            const std::string name = ggml_get_name(ti.w->tensor);
+            const uint32_t len = (uint32_t)name.size();
+            ofs.write((const char *)&len, sizeof(len));
+            ofs.write(name.data(), len);
+
+            const uint64_t cn = ti.candidate.size();
+            ofs.write((const char *)&cn, sizeof(cn));
+            ofs.write((const char *)&ti.choice, sizeof(ti.choice));
+            ofs.write((const char *)&ti.min_bpw, sizeof(ti.min_bpw));
+            ofs.write((const char *)&ti.max_bpw, sizeof(ti.max_bpw));
+            const uint64_t ne = ti.n_elements;
+            ofs.write((const char *)&ne, sizeof(ne));
+
+            for (const auto & c : ti.candidate) {
+                const int32_t t = c.type;
+                const uint64_t b = c.bytes;
+                ofs.write((const char *)&t, sizeof(t));
+                ofs.write((const char *)&c.bpw, sizeof(c.bpw));
+                ofs.write((const char *)&b, sizeof(b));
+                ofs.write((const char *)&c.error, sizeof(c.error));
+            }
+        }
+
+        // close() flushes; a failed write or flush (e.g. disk full) leaves the stream in a failed state
+        ofs.close();
+        if (ofs.fail()) {
+            std::remove(tmp.c_str());
+            LLAMA_LOG_WARN("%s: failed to write %s, bpw progress not saved\n", func, tmp.c_str());
+            return;
+        }
+
+        // rename() replaces an existing file on POSIX; where it refuses to (e.g. Windows), remove and retry
+        if (std::rename(tmp.c_str(), checkpoint_file.c_str()) != 0) {
+            std::remove(checkpoint_file.c_str());
+            if (std::rename(tmp.c_str(), checkpoint_file.c_str()) != 0) {
+                std::remove(tmp.c_str());
+                LLAMA_LOG_WARN("%s: failed to replace %s, bpw progress not saved\n", func, checkpoint_file.c_str());
+                return;
+            }
+        }
+        LLAMA_LOG_INFO("%s: saved bpw progress for %zu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str());
+    };
+
     // Estimate error for a given type using a sampled subset of rows
     auto estimate_error = [&](const ggml_tensor * t,
                               const ggml_type quant_type,