diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 03f8a4bd11..85191a66ae 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -596,10 +596,7 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } -// Returns per-tensor overrides of quantization types to meet target BPW with the lowest ppl -// sample_rows_per_expert: Larger values will result in more accurate error estimates, but will take longer to compute -// bias_lambda: Affects the weight of the bias term in the weigthed MSE error function. 0.0 means no bias (standard MSE), -// 1.0 means equal weight for bias and error, 2.0 means twice as much weight for bias +// Returns per-tensor type overrides to meet target BPW at lowest ppl static std::unordered_map target_bpw_type( llama_model_loader & ml, std::vector> & buffer, @@ -609,9 +606,7 @@ static std::unordered_map target_bpw_type( const std::unordered_map> * values_data, const std::unordered_map> * activations_data, const llama_model_quantize_params * params, - int nthread, - int sample_rows_per_expert = 512, - float bias_lambda = 1.0 + int nthread ) { struct candidate_types { ggml_type type; @@ -621,15 +616,15 @@ static std::unordered_map target_bpw_type( }; struct tensor_info { - const llama_model_loader::llama_tensor_weight * w; - std::vector candidate; + const llama_model_loader::llama_tensor_weight * w = nullptr; + std::vector candidate = {}; int choice = -1; float min_bpw = 0.0; float max_bpw = 0.0; size_t n_elements = 0; }; - const ggml_type k_candidates[] = { + constexpr ggml_type k_quants[] = { GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_0, @@ -648,7 +643,7 @@ static std::unordered_map target_bpw_type( #endif }; - const ggml_type iq_candidates[] = { + constexpr ggml_type iq_quants[] = { GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, @@ -665,9 +660,49 @@ static std::unordered_map target_bpw_type( GGML_TYPE_Q6_K }; - auto name_tn = LLM_TN(model.arch); - float target_bpw = params->target_bpw; + auto get_values = [&](const std::string & tensor_name) -> const float * { + if (!values_data) { return nullptr; } + const auto it = values_data->find(remap_imatrix(tensor_name, mapped)); + if (it == values_data->end()) { return nullptr; } + return it->second.data(); + }; + auto get_activations = [&](const std::string & tensor_name) -> const float * { + if (!activations_data) { return nullptr; } + const auto it = activations_data->find(remap_imatrix(tensor_name, mapped)); + if (it == activations_data->end()) { return nullptr; } + return it->second.data(); + }; + + auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t { + const int64_t n_per_row = t->ne[0]; + const int64_t nrows = t->ne[1]; + const int64_t ne2 = t->ne[2] > 0 ? 
t->ne[2] : 1; + const size_t row_sz = ggml_row_size(typ, n_per_row); + return (size_t)ne2 * (size_t)nrows * row_sz; + }; + + auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double { + const int64_t nelem = ggml_nelements(t); + const size_t bytes = tensor_bytes(t, typ); + return (double)bytes * 8.0 / (double)nelem; + }; + + auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool { + const int64_t n_per_row = t->ne[0]; + const int64_t blck = ggml_blck_size(typ); + if (blck <= 1) { return true; } + return n_per_row % blck == 0; + }; + + auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type { + if (is_compatible(t, typ)) { return typ; } + ggml_type fb = fallback_type(typ); + if (is_compatible(t, fb)) { return fb; } + return GGML_TYPE_F16; + }; + + auto name_tn = LLM_TN(model.arch); auto can_quantize = [&](const ggml_tensor * t) -> bool { const std::string name = ggml_get_name(t); bool q = name.rfind("weight") == name.size() - 6; @@ -705,231 +740,182 @@ static std::unordered_map target_bpw_type( return q; }; - auto get_values = [&](const std::string & tensor_name) -> const float * { - if (!values_data) { return nullptr; } - const auto it = values_data->find(remap_imatrix(tensor_name, mapped)); - if (it == values_data->end()) { return nullptr; } - return it->second.data(); - }; - - auto get_activations = [&](const std::string & tensor_name) -> const float * { - if (!activations_data) { return nullptr; } - const auto it = activations_data->find(remap_imatrix(tensor_name, mapped)); - if (it == activations_data->end()) { return nullptr; } - return it->second.data(); - }; - - auto total_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t { - const int64_t n_per_row = t->ne[0]; - const int64_t nrows = t->ne[1]; - const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; - const size_t row_sz = ggml_row_size(typ, n_per_row); - return (size_t)ne2 * (size_t)nrows * row_sz; - }; - - auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double { - const int64_t nelem = ggml_nelements(t); - const size_t bytes = total_bytes(t, typ); - return bytes * 8.0 / nelem; - }; - - auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool { - const int64_t n_per_row = t->ne[0]; - const int64_t blck = ggml_blck_size(typ); - if (blck <= 1) { return true; } - return n_per_row % blck == 0; - }; - - auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type { - if (is_compatible(t, typ)) { return typ; } - ggml_type fb = fallback_type(typ); - if (is_compatible(t, fb)) { return fb; } - return GGML_TYPE_F16; - }; - // Estimate error for a given type using a sampled subset of rows auto estimate_error = [&](const ggml_tensor * t, - const ggml_type typ, + const ggml_type quant_type, const std::vector & f32_sample, const std::vector & sample_rows_per_slice, const float * values_sample, const float * activations_sample, - std::vector & qbuf, - std::vector & deq) -> double + std::vector & quantized_buffer, + std::vector & dequantized_buffer) -> double { const int64_t n_per_row = t->ne[0]; const int64_t nrows = t->ne[1]; const int64_t ne2 = t->ne[2] > 0 ? 
t->ne[2] : 1; - const size_t nels = f32_sample.size(); - const size_t total_sampled_rows = nels / (size_t)n_per_row; - if (total_sampled_rows == 0) { return 0.0; } + const size_t sample_element_count = f32_sample.size(); + const size_t sample_row_count = sample_element_count / (size_t)n_per_row; + if (sample_row_count == 0) { return 0.0; } - const size_t row_sz = ggml_row_size(typ, n_per_row); - const size_t need_q = row_sz * total_sampled_rows; - if (qbuf.size() < need_q) { qbuf.resize(need_q); } - if (deq.size() < nels) { deq.resize(nels); } + const size_t row_size = ggml_row_size(quant_type, n_per_row); + const size_t buffer_size = row_size * sample_row_count; + if (quantized_buffer.size() < buffer_size) { quantized_buffer.resize(buffer_size); } + if (dequantized_buffer.size() < sample_element_count) { dequantized_buffer.resize(sample_element_count); } - // Precompute denominators: - // - x2_per_row: sum_j w[j]*x[j]^2 if w present else sum_j x[j]^2 - // - bden_per_slice: sum_j w[j]*a[j]^2 if w & a present; sum_j a[j]^2 if only a present; 0 otherwise - std::vector x2_per_row(total_sampled_rows, 0.0); - std::vector bden_per_slice(ne2, 0.0); + std::vector row_sq_norm(sample_row_count, 0.0); + std::vector bias_denominator_per_slice(ne2, 0.0); - const bool has_w = (values_sample != nullptr); - const bool has_a = (activations_sample != nullptr); - - // Precompute bden per slice (depends only on w,a) - if (has_a) { + // Precompute bias denominator per slice + const bool has_values = (values_sample != nullptr); + const bool has_activations = (activations_sample != nullptr); + if (has_activations) { for (int64_t s = 0; s < ne2; ++s) { - const float * wv = has_w ? values_sample + s * n_per_row : nullptr; - const float * act = activations_sample + s * n_per_row; - double bden = 0.0; - if (has_w) { + const float * values = has_values ? values_sample + s * n_per_row : nullptr; + const float * activations = activations_sample + s * n_per_row; + double bias_denominator = 0.0; + if (has_values) { for (int64_t j = 0; j < n_per_row; ++j) { - const double a = act[j]; - bden += (double) wv[j] * a * a; + const double a = activations[j]; + bias_denominator += values[j] * a * a; } } else { for (int64_t j = 0; j < n_per_row; ++j) { - const double a = act[j]; - bden += a * a; + const double a = activations[j]; + bias_denominator += a * a; } } - bden_per_slice[s] = bden; + bias_denominator_per_slice[s] = bias_denominator; } } - // Precompute x2 per sampled row + // Compute squared norms of sampled rows { - size_t off = 0; + size_t offset = 0; size_t row_idx = 0; for (int64_t s = 0; s < ne2; ++s) { const int64_t rs = sample_rows_per_slice[s]; if (rs == 0) { continue; } - const float * wv = has_w ? values_sample + s * n_per_row : nullptr; + const float * values = has_values ? 
values_sample + s * n_per_row : nullptr; for (int64_t r = 0; r < rs; ++r, ++row_idx) { - const float * x = f32_sample.data() + off; - double x2 = 0.0; - if (has_w) { + const float * row = f32_sample.data() + offset; + double rsn = 0.0; + if (has_values) { for (int64_t j = 0; j < n_per_row; ++j) { - const double w = wv[j]; - const double xx = x[j]; - x2 += w * xx * xx; + const double v = values[j]; + const double x = row[j]; + rsn += v * x * x; } } else { for (int64_t j = 0; j < n_per_row; ++j) { - const double xx = x[j]; - x2 += xx * xx; + const double x = row[j]; + rsn += x * x; } } - x2_per_row[row_idx] = x2; - off += (size_t)n_per_row; + row_sq_norm[row_idx] = rsn; + offset += (size_t)n_per_row; } } } - // Quantize sampled rows slice-by-slice into qbuf - size_t qoff = 0; - size_t foff = 0; + // Quantize sampled rows slice-by-slice into quantized_buffer + size_t quantised_offset = 0; + size_t floats_offset = 0; for (int64_t slice = 0; slice < ne2; ++slice) { const int64_t rs = sample_rows_per_slice[slice]; if (rs == 0) { continue; } const float * value = values_sample ? values_sample + slice * n_per_row : nullptr; - (void)ggml_quantize_chunk(typ, f32_sample.data() + foff, qbuf.data() + qoff, 0, rs, n_per_row, value); + (void)ggml_quantize_chunk(quant_type, f32_sample.data() + floats_offset, quantized_buffer.data() + quantised_offset, 0, rs, n_per_row, value); - qoff += row_sz * (size_t)rs; - foff += (size_t)rs * (size_t)n_per_row; + quantised_offset += row_size * (size_t)rs; + floats_offset += (size_t)rs * (size_t)n_per_row; } - // Dequantize into deq (row-wise if needed to avoid int overflow) + // Dequantize into dequantized_buffer { - const ggml_type_traits * traits = ggml_get_type_traits(typ); - if (typ == GGML_TYPE_F16) { - ggml_fp16_to_fp32_row((const ggml_fp16_t *)qbuf.data(), deq.data(), (int)nels); - } else if (typ == GGML_TYPE_BF16) { - ggml_bf16_to_fp32_row((const ggml_bf16_t *)qbuf.data(), deq.data(), (int)nels); + const ggml_type_traits * traits = ggml_get_type_traits(quant_type); + if (quant_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count); + } else if (quant_type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((const ggml_bf16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count); } else { if (!traits || !traits->to_float) { - LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(typ)); + LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type)); return 1e35; } size_t done = 0; - while (done < nels) { - const size_t chunk = std::min((size_t)n_per_row, nels - done); - traits->to_float(qbuf.data() + done / n_per_row * row_sz, deq.data() + done, (int)chunk); + while (done < sample_element_count) { + const size_t chunk = std::min((size_t)n_per_row, sample_element_count - done); + traits->to_float(quantized_buffer.data() + done / n_per_row * row_size, dequantized_buffer.data() + done, (int)chunk); done += chunk; } } } // Compute error - const double eps = 1e-12; - size_t off = 0; + size_t offset = 0; size_t row_idx = 0; double total_err = 0.0; - for (int64_t slice = 0; slice < ne2; ++slice) { const int64_t rs = sample_rows_per_slice[slice]; if (rs == 0) { continue; } - const float * wv = has_w ? values_sample + slice * n_per_row : nullptr; - const float * act = has_a ? activations_sample + slice * n_per_row : nullptr; - const double bden = has_a ? 
bden_per_slice[slice] : 0.0; - + const float * values = has_values ? values_sample + slice * n_per_row : nullptr; + const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr; + const double bias_denominator = has_activations ? bias_denominator_per_slice[slice] : 0.0; double slice_err = 0.0; - for (int64_t r = 0; r < rs; ++r, ++row_idx) { - const float * x = f32_sample.data() + off; - const float * y = deq.data() + off; - - double mse_w = 0.0; - double bnum = 0.0; - - if (wv && act) { + const float * x = f32_sample.data() + offset; + const float * y = dequantized_buffer.data() + offset; + double weighted_mse = 0.0; + double bias_numerator = 0.0; + if (values && activations) { for (int64_t j = 0; j < n_per_row; ++j) { - const double w = wv[j]; + const double v = values[j]; const double e = y[j] - x[j]; - const double a = act[j]; - mse_w += w * e * e; - bnum += w * e * a; + const double a = activations[j]; + weighted_mse += v * e * e; + bias_numerator += v * e * a; } - } else if (wv) { + } else if (values) { for (int64_t j = 0; j < n_per_row; ++j) { - const double w = wv[j]; + const double v = values[j]; const double e = y[j] - x[j]; - mse_w += w * e * e; + weighted_mse += v * e * e; } - } else if (act) { + } else if (activations) { for (int64_t j = 0; j < n_per_row; ++j) { const double e = y[j] - x[j]; - const double a = act[j]; - mse_w += e * e; - bnum += e * a; + const double a = activations[j]; + weighted_mse += e * e; + bias_numerator += e * a; } } else { for (int64_t j = 0; j < n_per_row; ++j) { const double e = y[j] - x[j]; - mse_w += e * e; + weighted_mse += e * e; } } - // corrected normalization: divide the full numerator by x2 - double numer = mse_w; - if (act && bias_lambda != 0.0) { - const double proj = bnum * bnum / (bden + eps); - numer += bias_lambda * proj; + double err_numerator = weighted_mse; + constexpr double epsilon = 1e-12; + constexpr float bias_lambda = 1.0; + // bias_lambda defines the weight of the bias term in the weighted MSE error function: + // 0.0 means no bias (standard MSE), 1.0 means equal weight for bias and error, + // 2.0 means twice as much weight for bias, etc. + if (activations && bias_lambda != 0.0) { + const double proj = bias_numerator * bias_numerator / (bias_denominator + epsilon); + err_numerator += bias_lambda * proj; } - const double denom = x2_per_row[row_idx] + eps; - const double row_err = numer / denom; - + const double err_denominator = row_sq_norm[row_idx] + epsilon; + const double row_err = err_numerator / err_denominator; slice_err += row_err; - off += (size_t)n_per_row; + offset += (size_t)n_per_row; } // scale to full rows (nrows) @@ -942,14 +928,11 @@ static std::unordered_map target_bpw_type( std::vector all; all.reserve(tensors.size()); - for (const auto * tw : tensors) { std::vector workers; workers.reserve(std::max(1, nthread)); - ggml_tensor * t = tw->tensor; const std::string name = ggml_get_name(t); - if (!can_quantize(t)) { continue; } LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t)); @@ -959,37 +942,26 @@ static std::unordered_map target_bpw_type( } ml.load_data_for(t); - const int64_t nelem = ggml_nelements(t); - std::vector> f32_conv_buf; - const float * values_all = get_values(name); - const float * activations_all = get_activations(name); - // Dequantize only sampled rows into f32_sample const int64_t n_per_row = t->ne[0]; const int64_t nrows_total = t->ne[1]; const int64_t ne2 = t->ne[2] > 0 ? 
t->ne[2] : 1; + // Larger sample_rows_per_expert values may result in more accurate error estimates, but will take longer to compute + int sample_rows_per_expert = 512; + std::vector f32_sample; + f32_sample.reserve((size_t)ne2 * (size_t)std::min(nrows_total, sample_rows_per_expert) * (size_t)n_per_row); + + // deterministic sampling seed based on tensor name + fixed constant + std::mt19937 rng(std::hash{}(name) ^0xeabada55cafed00d); + std::vector sample_rows_per_slice(ne2, 0); + const int64_t sample_rows_max = std::max(1, std::min(nrows_total, sample_rows_per_expert)); + const int64_t stride = std::max(1, nrows_total / sample_rows_max); + std::vector row_buffer(n_per_row); const ggml_type src_type = t->type; const ggml_type_traits *src_traits = ggml_get_type_traits(src_type); const bool src_is_quant = ggml_is_quantized(src_type); const size_t src_row_sz = ggml_row_size(src_type, n_per_row); - - std::vector f32_sample; - f32_sample.reserve((size_t)ne2 * (size_t)std::min(nrows_total, sample_rows_per_expert) * (size_t)n_per_row); - - std::vector values_sample; - std::vector activations_sample; - std::vector sample_rows_per_slice(ne2, 0); - - // deterministic sampling seed based on tensor name + fixed constant - std::mt19937 rng(std::hash{}(name) ^0xeabada55cafed00d); - - const int64_t sample_rows_max = std::max(1, std::min(nrows_total, sample_rows_per_expert)); - const int64_t stride = std::max(1, nrows_total / sample_rows_max); - - // Temporary buffer for one dequantized row - std::vector rowbuf((size_t)n_per_row); - for (int64_t slice = 0; slice < ne2; ++slice) { int64_t current_sampled_rows = 0; int64_t offset = 0; @@ -1004,19 +976,19 @@ static std::unordered_map target_bpw_type( f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row); } else if (src_type == GGML_TYPE_F16) { const ggml_fp16_t * src_row = (const ggml_fp16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); - ggml_fp16_to_fp32_row(src_row, rowbuf.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); + ggml_fp16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row); + f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); } else if (src_type == GGML_TYPE_BF16) { const ggml_bf16_t * src_row = (const ggml_bf16_t *)((const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); - ggml_bf16_to_fp32_row(src_row, rowbuf.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); + ggml_bf16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row); + f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); } else if (src_is_quant) { const uint8_t * qrow = (const uint8_t *)t->data + slice * (src_row_sz * nrows_total) + r * src_row_sz; if (!src_traits || !src_traits->to_float) { throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type))); } - src_traits->to_float(qrow, rowbuf.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), rowbuf.begin(), rowbuf.end()); + src_traits->to_float(qrow, row_buffer.data(), (int)n_per_row); + f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); } else { throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type))); } @@ -1045,6 +1017,10 @@ static std::unordered_map target_bpw_type( } }; + const float * values_all = get_values(name); + const float * activations_all = get_activations(name); + std::vector values_sample; + 
std::vector activations_sample; if (values_all) { // get size from the map (not just the raw pointer) auto itv = values_data->find(remap_imatrix(name, mapped)); @@ -1057,6 +1033,7 @@ static std::unordered_map target_bpw_type( copy_or_broadcast(activations_all, sz, activations_sample); } + const int64_t nelem = ggml_nelements(t); tensor_info info; info.w = tw; info.n_elements = nelem; @@ -1067,12 +1044,12 @@ static std::unordered_map target_bpw_type( // Build list of candidate types first (compatible ones) std::vector quant_candidates; if (is_iq(params->ftype)) { - quant_candidates.assign(std::begin(iq_candidates), std::end(iq_candidates)); + quant_candidates.assign(std::begin(iq_quants), std::end(iq_quants)); } else { - quant_candidates.assign(std::begin(k_candidates), std::end(k_candidates)); + quant_candidates.assign(std::begin(k_quants), std::end(k_quants)); } - // Compute maximum row size among compatible candidates (to size qbuf once) + // Compute maximum row size among compatible candidates (to size quantized_buffer once) size_t max_row_sz = 0; const bool has_valid_imatrix = !values_sample.empty() && values_sample.size() == (size_t)ne2 * (size_t)n_per_row; std::vector compatible_candidates; @@ -1092,21 +1069,20 @@ static std::unordered_map target_bpw_type( compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end()); // Now evaluate candidates - std::vector cand_out(compatible_candidates.size()); - const float *vals_ptr = values_sample.empty() ? nullptr : values_sample.data(); - const float *acts_ptr = activations_sample.empty() ? nullptr : activations_sample.data(); - std::vector qbuf(max_row_sz * total_sampled_rows); - std::vector deq(f32_sample.size()); + std::vector eval_candidates(compatible_candidates.size()); + const float *values = values_sample.empty() ? nullptr : values_sample.data(); + const float *activations = activations_sample.empty() ? 
nullptr : activations_sample.data(); + std::vector quantized_buffer(max_row_sz * total_sampled_rows); + std::vector dequantized_buffer(f32_sample.size()); int n_eval_threads = std::max(1, std::min(nthread, (int)compatible_candidates.size())); std::atomic cidx{0}; std::vector eval_workers; eval_workers.reserve(n_eval_threads); - for (int ti = 0; ti < n_eval_threads; ++ti) { eval_workers.emplace_back([&] { // thread-local scratch - std::vector tl_qbuf(qbuf.size()); - std::vector tl_deq(deq.size()); + std::vector tl_quantized_buffer(quantized_buffer.size()); + std::vector tl_dequantized_buffer(dequantized_buffer.size()); for (;;) { const size_t i = cidx.fetch_add(1, std::memory_order_relaxed); @@ -1114,15 +1090,16 @@ static std::unordered_map target_bpw_type( const ggml_type tt = compatible_candidates[i]; const auto bpw = (float)tensor_bpw(t, tt); - const size_t bytes = total_bytes(t, tt); - const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, vals_ptr, acts_ptr, tl_qbuf, tl_deq); - cand_out[i] = candidate_types{ tt, bpw, bytes, err }; + const size_t bytes = tensor_bytes(t, tt); + const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, values, activations, tl_quantized_buffer, tl_dequantized_buffer); + eval_candidates[i] = candidate_types{ tt, bpw, bytes, err }; } }); } + for (auto &th : eval_workers) { th.join(); } - for (auto &c : cand_out) { + for (auto &c : eval_candidates) { if (c.bytes > 0) { info.candidate.push_back(c); } } @@ -1132,7 +1109,7 @@ static std::unordered_map target_bpw_type( info.candidate.push_back(candidate_types{ t->type, bpw, ggml_nbytes(t), 0.0 }); } - // Remove dominated candidates: if A has >= bytes and >= error than B (and > in at least one), drop A. + // Keep only the Pareto-optimal candidates: drop A if another candidate B has <= bytes and <= error. 
{ std::vector pruned; pruned.reserve(info.candidate.size()); @@ -1155,36 +1132,37 @@ static std::unordered_map target_bpw_type( info.candidate.swap(pruned); } - std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) { - if (a.bpw != b.bpw) { return a.bpw < b.bpw; } - if (a.error != b.error) { return a.error < b.error; } - return a.bytes < b.bytes; - }); - // Collapse candidates with identical storage size (bytes) { - std::vector uniq; - uniq.reserve(info.candidate.size()); + std::vector unique; + unique.reserve(info.candidate.size()); + // Sort by bpw asc, error asc, bytes asc + std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) { + if (a.bpw != b.bpw) { return a.bpw < b.bpw; } + if (a.error != b.error) { return a.error < b.error; } + return a.bytes < b.bytes; + }); for (size_t i = 0; i < info.candidate.size();) { - size_t j = i + 1; + size_t j = i + 1; candidate_types best = info.candidate[i]; // group same-byte entries, keep the one with the lowest error while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) { - if (info.candidate[j].error < best.error) { best = info.candidate[j]; } + if (info.candidate[j].error < best.error) { + best = info.candidate[j]; + } ++j; } - uniq.push_back(best); + unique.push_back(best); i = j; } - info.candidate.swap(uniq); + info.candidate.swap(unique); } // Initialize choice at the smallest bpw candidate info.choice = 0; info.min_bpw = info.candidate.front().bpw; info.max_bpw = info.candidate.back().bpw; - all.push_back(std::move(info)); } @@ -1196,6 +1174,7 @@ static std::unordered_map target_bpw_type( for (const auto & ti : all) { b += ti.candidate[ti.choice].bytes; } + return b; }; @@ -1204,6 +1183,7 @@ static std::unordered_map target_bpw_type( for (const auto & ti : all) { w += ti.n_elements; } + return w; }; @@ -1215,12 +1195,14 @@ static std::unordered_map target_bpw_type( // Precompute current bpw double bpw_now = current_bpw(); + float target_bpw = params->target_bpw; // If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw) if (bpw_now >= target_bpw) { std::unordered_map overrides; for (const auto & ti : all) { overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type; } + return overrides; } @@ -1268,6 +1250,7 @@ static std::unordered_map target_bpw_type( best = upgrade{ i, j, err, delta_bytes, ratio }; } } + return best; }; @@ -1286,16 +1269,12 @@ static std::unordered_map target_bpw_type( } } - // We might still be below target but taking any single upgrade overshoots. - // Try to find the best upgrade that overshoots the target_bpw by the least and has the best error-to-size ratio. 
+ // We might still be below target, but any remaining single upgrade overshoots; take the one that lands closest to target_bpw (ties broken by the best error-to-size ratio) { - double under_gap = target_bpw - bpw_now; - upgrade best_over{ -1, -1, 0.0, 0, -1.0 }; double best_over_gap = 1e300; - + double under_gap = target_bpw - bpw_now; size_t now_bytes = current_total_bytes(); - for (int i = 0; i < (int) all.size(); ++i) { const auto & ti = all[i]; if (ti.choice >= (int)ti.candidate.size() - 1) { continue; } @@ -1305,19 +1284,16 @@ static std::unordered_map target_bpw_type( const auto & cur = ti.candidate[ti.choice]; const auto & nxt = ti.candidate[j]; - size_t delta_bytes = nxt.bytes - cur.bytes; if (delta_bytes == 0) { continue; } size_t over_bytes = now_bytes + delta_bytes; double bpw_over = (double)over_bytes * 8.0 / (double)tw; - - double over_gap = std::abs(bpw_over - (double)target_bpw); - double err = cur.error - nxt.error; if (err < 0.0) { err = 0.0; } double ratio = err / (double)(delta_bytes * 8ull); + double over_gap = std::abs(bpw_over - (double)target_bpw); if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) { best_over_gap = over_gap; best_over = upgrade{ i, j, err, delta_bytes, ratio }; @@ -1339,6 +1315,7 @@ static std::unordered_map target_bpw_type( __func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error); overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type; } + return overrides; }
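
Note (reviewer addition, not part of the patch): once every tensor has a Pareto-pruned candidate list sorted by ascending bpw, target_bpw_type() amounts to a greedy budget allocation: start every tensor at its smallest candidate, repeatedly apply the upgrade with the best error reduction per extra bit while the projected overall bpw stays at or below the target, then allow at most one overshooting upgrade that lands closest to the target. The sketch below restates only that core loop with plain data structures; the names (candidate_info, tensor_choice, pick_types) are hypothetical, it considers only each tensor's immediately next candidate, and it omits the final least-overshoot pass.

#include <cstddef>
#include <vector>

struct candidate_info {
    double bpw;     // bits per weight of this candidate type
    size_t bytes;   // total tensor size at this type
    double error;   // estimated weighted-MSE error
};

struct tensor_choice {
    std::vector<candidate_info> candidates; // assumed non-empty, Pareto-pruned, sorted by ascending bpw
    int    choice     = 0;                  // start at the smallest candidate
    size_t n_elements = 0;
};

// Greedy allocation: repeatedly take the upgrade with the best error reduction per
// extra bit, as long as the projected overall bpw stays at or below target_bpw.
static void pick_types(std::vector<tensor_choice> & tensors, double target_bpw) {
    size_t total_elements = 0;
    size_t total_bytes    = 0;
    for (const auto & t : tensors) {
        total_elements += t.n_elements;
        total_bytes    += t.candidates[t.choice].bytes;
    }
    if (total_elements == 0) { return; }

    for (;;) {
        int    best_i     = -1;
        double best_ratio = -1.0;
        for (int i = 0; i < (int) tensors.size(); ++i) {
            const auto & t = tensors[i];
            if (t.choice + 1 >= (int) t.candidates.size()) { continue; }
            const auto & cur = t.candidates[t.choice];
            const auto & nxt = t.candidates[t.choice + 1];
            const size_t delta_bytes = nxt.bytes - cur.bytes;
            if (delta_bytes == 0) { continue; }
            const double bpw_after = (double) (total_bytes + delta_bytes) * 8.0 / (double) total_elements;
            if (bpw_after > target_bpw) { continue; } // this upgrade would overshoot the budget
            const double err_gain = cur.error > nxt.error ? cur.error - nxt.error : 0.0;
            const double ratio    = err_gain / (double) (delta_bytes * 8);
            if (ratio > best_ratio) { best_ratio = ratio; best_i = i; }
        }
        if (best_i < 0) { break; } // no remaining upgrade fits under the target
        auto & t = tensors[best_i];
        total_bytes += t.candidates[t.choice + 1].bytes - t.candidates[t.choice].bytes;
        t.choice    += 1;
    }
}

In the patch itself the same bookkeeping runs over tensor_info::candidate, with the upgrade struct { i, j, err, delta_bytes, ratio } carrying the selected (tensor, candidate) pair, and the final block above handles the case where every remaining upgrade overshoots the target.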