Minor refactoring

This commit is contained in:
Ed Addario 2025-10-16 15:11:48 +01:00
parent 0b3e930d52
commit a5103933bb
GPG Key ID: E7875815A3230993
1 changed file with 33 additions and 18 deletions


@@ -647,7 +647,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
            std::signal(SIGINT, prev_int);
            std::signal(SIGTERM, prev_term);
        }
    } _signal_guard;
    } signal_guard;
    struct candidate_types {
        ggml_type type;
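For context, the signal_guard object (renamed here from _signal_guard) is a plain RAII helper: its destructor puts the previous SIGINT/SIGTERM handlers back no matter how the function exits. A minimal standalone sketch of that pattern, using hypothetical names rather than the commit's exact code:

// Illustrative sketch only (hypothetical names): an RAII guard that installs a
// temporary handler and restores the previous SIGINT/SIGTERM handlers on destruction.
#include <csignal>

struct scoped_signal_guard {
    using handler_t = void (*)(int);
    handler_t prev_int;
    handler_t prev_term;

    explicit scoped_signal_guard(handler_t tmp) {
        prev_int  = std::signal(SIGINT,  tmp); // std::signal returns the previous handler
        prev_term = std::signal(SIGTERM, tmp);
    }
    ~scoped_signal_guard() {
        std::signal(SIGINT,  prev_int);        // restore on every exit path
        std::signal(SIGTERM, prev_term);
    }
};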
@@ -683,7 +683,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
        GGML_TYPE_Q5_K,
        GGML_TYPE_Q6_K,
        GGML_TYPE_Q8_0,
#ifdef GGML_USE_METAL
        GGML_TYPE_F16
#else
        GGML_TYPE_BF16
#endif
    };
    constexpr double epsilon = 1e-12;
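The types listed above are the quantization candidates the bpw search can assign; on Metal builds the highest-precision candidate becomes F16 instead of BF16. For reference, the nominal bits per weight of each candidate can be read off ggml's block metadata, as in this rough sketch that uses only the public ggml.h API:

// Illustrative sketch only: print the nominal bits-per-weight of the candidate types.
#include <cstdio>
#include "ggml.h"

int main() {
    const ggml_type candidates[] = {
        GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, GGML_TYPE_Q8_0,
#ifdef GGML_USE_METAL
        GGML_TYPE_F16,
#else
        GGML_TYPE_BF16,
#endif
    };
    for (ggml_type t : candidates) {
        // bits per weight = 8 * bytes per block / elements per block
        const double bpw = 8.0 * ggml_type_size(t) / ggml_blck_size(t);
        printf("%-8s %.2f bpw\n", ggml_type_name(t), bpw);
    }
    return 0;
}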
@@ -1004,17 +1008,30 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
    // Dequantize into dequantized_buffer
    {
        const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
        if (!traits || !traits->to_float) {
            if (out_mse) { *out_mse = infinity; }
            if (out_proj) { *out_proj = 0.0; }
            return infinity;
        }
        for (size_t r = 0; r < sample_rows; ++r) {
            const uint8_t * src = quantized_buffer.data() + r * row_sz;
            float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
            traits->to_float(src, dst, (int)n_per_row);
        if (quant_type == GGML_TYPE_F16) {
            for (size_t r = 0; r < sample_rows; ++r) {
                auto src = (const ggml_fp16_t *)(quantized_buffer.data() + r * row_sz);
                float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
                ggml_fp16_to_fp32_row(src, dst, (int)n_per_row);
            }
        } else if (quant_type == GGML_TYPE_BF16) {
            for (size_t r = 0; r < sample_rows; ++r) {
                auto src = (const ggml_bf16_t *)(quantized_buffer.data() + r * row_sz);
                float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
                ggml_bf16_to_fp32_row(src, dst, (int)n_per_row);
            }
        } else {
            const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
            if (!traits || !traits->to_float) {
                if (out_mse) { *out_mse = infinity; }
                if (out_proj) { *out_proj = 0.0; }
                return infinity;
            }
            for (size_t r = 0; r < sample_rows; ++r) {
                const uint8_t * src = quantized_buffer.data() + r * row_sz;
                float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
                traits->to_float(src, dst, (int)n_per_row);
            }
        }
    }
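The new branches handle F16 and BF16 rows with the dedicated ggml_fp16_to_fp32_row / ggml_bf16_to_fp32_row helpers, while every other type still goes through the generic to_float callback returned by ggml_get_type_traits. A condensed standalone sketch of that dispatch (hypothetical function name, simplified to a single row and minimal error handling):

// Illustrative sketch only: dequantize one row of `type` data into `dst`.
// Returns false when the type has no usable to_float callback.
#include <cstdint>
#include "ggml.h"

static bool dequantize_row(ggml_type type, const void * src, float * dst, int64_t n_per_row) {
    if (type == GGML_TYPE_F16) {
        ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, n_per_row);
        return true;
    }
    if (type == GGML_TYPE_BF16) {
        ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, n_per_row);
        return true;
    }
    const ggml_type_traits * traits = ggml_get_type_traits(type);
    if (!traits || !traits->to_float) {
        return false; // the caller treats this as an "infinite" error, as the commit does
    }
    traits->to_float(src, dst, n_per_row);
    return true;
}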
@@ -1500,13 +1517,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
    // Compute total elements across all tensors and bytes for non-quantizable tensors
    size_t nq_elements = 0;
    size_t nq_bytes = 0;
    for (const auto & it : ml.weights_map) {
        const ggml_tensor * tensor = it.second.tensor;
        const std::string name = it.first;
    for (const auto * it : tensors) {
        const ggml_tensor * tensor = it->tensor;
        const std::string name = ggml_get_name(tensor);
        nq_elements += (size_t)ggml_nelements(tensor);
        if (!is_quantizable(name, model.arch, params)) {
            nq_bytes += ggml_nbytes(tensor);
        }
        if (!can_quantize(tensor)) { nq_bytes += ggml_nbytes(tensor); }
    }
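The loop now walks a pre-filtered tensors vector and asks can_quantize(tensor) directly instead of re-deriving the answer from the tensor name; the accounting itself is just two running totals. A rough sketch of that accounting, where keep_as_is is a hypothetical stand-in for the commit's predicate:

// Illustrative sketch only: total element count across tensors, plus the byte
// footprint of tensors that will be kept at their current type.
#include <cstddef>
#include <vector>
#include "ggml.h"

static void tally(const std::vector<const ggml_tensor *> & tensors,
                  bool (*keep_as_is)(const ggml_tensor *),
                  size_t & total_elements, size_t & kept_bytes) {
    total_elements = 0;
    kept_bytes     = 0;
    for (const ggml_tensor * t : tensors) {
        total_elements += (size_t) ggml_nelements(t);
        if (keep_as_is(t)) {
            kept_bytes += ggml_nbytes(t);
        }
    }
}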
    auto total_bytes = [&]() -> size_t {