diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 3080b0ed71..4d0dc6a36e 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -672,7 +672,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
     constexpr double epsilon  = 1e-12;
     constexpr double infinity = std::numeric_limits<double>::infinity();
 
+    constexpr uint32_t file_magic = 0x42505731; // BPW1
     const char * func = __func__;
+    const std::string checkpoint_file = ml.arch_name + ".bpw_state";
 
     auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
         const int64_t n_per_row = t->ne[0];
@@ -784,6 +786,68 @@
         LLAMA_LOG_INFO("%s: saved bpw progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str());
     };
 
+    auto load_bpw_state = [&]() -> std::unordered_map<std::string, saved_info> {
+        std::unordered_map<std::string, saved_info> out;
+        std::ifstream ifs(checkpoint_file, std::ios::binary);
+        if (!ifs) { return out; }
+
+        uint32_t magic = 0;
+        float target_bpw = 0.0f;
+        uint8_t bias_mode = 0;
+        ifs.read((char *)&magic, sizeof(magic));
+        ifs.read((char *)&target_bpw, sizeof(target_bpw));
+        ifs.read((char *)&bias_mode, sizeof(bias_mode));
+        if (magic != file_magic) {
+            LLAMA_LOG_WARN("%s: invalid resume file, ignoring: %s\n", func, checkpoint_file.c_str());
+            return out;
+        }
+        if (target_bpw != params->target_bpw) {
+            LLAMA_LOG_WARN("%s: target bpw of %f does not match %f, ignoring: %s\n", func, params->target_bpw, target_bpw, checkpoint_file.c_str());
+            return out;
+        }
+        if (bias_mode != (params->no_bias ? 1 : 0)) {
+            LLAMA_LOG_WARN("%s: bias mode does not match, ignoring: %s\n", func, checkpoint_file.c_str());
+            return out;
+        }
+
+        uint64_t n = 0;
+        ifs.read((char *)&n, sizeof(n));
+        for (uint64_t i = 0; i < n; ++i) {
+            uint32_t len = 0;
+            ifs.read((char *)&len, sizeof(len));
+            std::string name(len, '\0');
+            ifs.read(name.data(), len);
+
+            uint64_t cn = 0;
+            ifs.read((char *)&cn, sizeof(cn));
+
+            saved_info si;
+            ifs.read((char *)&si.choice, sizeof(si.choice));
+            ifs.read((char *)&si.min_bpw, sizeof(si.min_bpw));
+            ifs.read((char *)&si.max_bpw, sizeof(si.max_bpw));
+            uint64_t ne = 0;
+            ifs.read((char *)&ne, sizeof(ne));
+            si.n_elements = (size_t)ne;
+
+            si.candidate.resize(cn);
+            for (size_t j = 0; j < si.candidate.size(); ++j) {
+                int32_t t = 0;
+                uint64_t b = 0;
+                ifs.read((char *)&t, sizeof(t));
+                si.candidate[j].type = (ggml_type)t;
+                ifs.read((char *)&si.candidate[j].bpw, sizeof(si.candidate[j].bpw));
+                ifs.read((char *)&b, sizeof(b));
+                si.candidate[j].bytes = (size_t)b;
+                ifs.read((char *)&si.candidate[j].error, sizeof(si.candidate[j].error));
+            }
+
+            out.emplace(std::move(name), std::move(si));
+        }
+
+        LLAMA_LOG_INFO("%s: loaded bpw state for %lu tensors from %s\n", func, out.size(), checkpoint_file.c_str());
+        return out;
+    };
+
     // Estimate error for a given type using a sampled subset of rows
     auto estimate_error = [&](const ggml_tensor * t, const ggml_type quant_type,