Add save_bpw_state()
This commit is contained in:
parent
533cda3076
commit
e48ca32f19
|
|
@ -734,6 +734,56 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
});
|
||||
};
|
||||
|
||||
// Saved state per tensor
|
||||
struct saved_info {
|
||||
std::vector<candidate_types> candidate;
|
||||
int choice = -1;
|
||||
float min_bpw = 0.0f;
|
||||
float max_bpw = 0.0f;
|
||||
size_t n_elements = 0;
|
||||
};
|
||||
|
||||
auto save_bpw_state = [&](const std::vector<tensor_info> & all_vec) {
|
||||
const std::string tmp = checkpoint_file + ".tmp";
|
||||
std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc);
|
||||
if (!ofs) { return; } // best-effort
|
||||
const float target_bpw = params->target_bpw;
|
||||
const uint8_t bias_mode = params->no_bias ? 1 : 0;
|
||||
ofs.write((const char *)&file_magic, sizeof(file_magic));
|
||||
ofs.write((const char *)&target_bpw, sizeof(target_bpw));
|
||||
ofs.write((const char *)&bias_mode, sizeof(bias_mode));
|
||||
const uint64_t n = all_vec.size();
|
||||
ofs.write((const char *)&n, sizeof(n));
|
||||
for (const auto & ti : all_vec) {
|
||||
const std::string name = ggml_get_name(ti.w->tensor);
|
||||
const uint32_t len = (uint32_t)name.size();
|
||||
ofs.write((const char *)&len, sizeof(len));
|
||||
ofs.write(name.data(), len);
|
||||
|
||||
const uint64_t cn = ti.candidate.size();
|
||||
ofs.write((const char *)&cn, sizeof(cn));
|
||||
ofs.write((const char *)&ti.choice, sizeof(ti.choice));
|
||||
ofs.write((const char *)&ti.min_bpw, sizeof(ti.min_bpw));
|
||||
ofs.write((const char *)&ti.max_bpw, sizeof(ti.max_bpw));
|
||||
const uint64_t ne = ti.n_elements;
|
||||
ofs.write((const char *)&ne, sizeof(ne));
|
||||
|
||||
for (const auto & c : ti.candidate) {
|
||||
const int32_t t = c.type;
|
||||
const uint64_t b = c.bytes;
|
||||
ofs.write((const char *)&t, sizeof(t));
|
||||
ofs.write((const char *)&c.bpw, sizeof(c.bpw));
|
||||
ofs.write((const char *)&b, sizeof(b));
|
||||
ofs.write((const char *)&c.error, sizeof(c.error));
|
||||
}
|
||||
}
|
||||
|
||||
ofs.close();
|
||||
std::remove(checkpoint_file.c_str()); // TODO: handle errors
|
||||
std::rename(tmp.c_str(), checkpoint_file.c_str());
|
||||
LLAMA_LOG_INFO("%s: saved bpw progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str());
|
||||
};
|
||||
|
||||
// Estimate error for a given type using a sampled subset of rows
|
||||
auto estimate_error = [&](const ggml_tensor * t,
|
||||
const ggml_type quant_type,
|
||||
|
|
|
|||
Loading…
Reference in New Issue