From e8e2aed17a4ade7b14021e05f2a55f9b8f26510f Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 21 Sep 2025 13:41:44 +0100 Subject: [PATCH] Refactor row sampling --- src/llama-quant.cpp | 49 +++++++++++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 9e7d9d295c..4a8c08e68f 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1029,7 +1029,6 @@ static std::unordered_map target_bpw_type( const int64_t nrows_total = tensor->ne[1]; const int64_t ne2 = tensor->ne[2] > 0 ? tensor->ne[2] : 1; - // Larger rows_sample_per_expert values may result in more accurate error estimates, but it will take much longer to compute const int rows_sample_per_expert = activations_data ? 512 : 256; std::vector f32_sample; f32_sample.reserve((size_t)ne2 * (size_t)std::min(nrows_total, rows_sample_per_expert) * (size_t)n_per_row); @@ -1037,11 +1036,30 @@ static std::unordered_map target_bpw_type( std::vector rows_sample(ne2, 0); const int64_t rows_sample_max = std::max(1, std::min(nrows_total, rows_sample_per_expert)); const int64_t stride = std::max(1, nrows_total / rows_sample_max); - std::vector row_buffer(n_per_row); const ggml_type src_type = tensor->type; - const ggml_type_traits *src_traits = ggml_get_type_traits(src_type); + const ggml_type_traits * src_traits = ggml_get_type_traits(src_type); const bool src_is_quant = ggml_is_quantized(src_type); const size_t src_row_sz = ggml_row_size(src_type, n_per_row); + + std::vector row_buffer(n_per_row); + auto row_to_fp32 = [&](const uint8_t * src, float * dst) { + if (src_type == GGML_TYPE_F32) { + std::memcpy(dst, src, sizeof(float) * (size_t)n_per_row); + } else if (src_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row); + } else if (src_type == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row); + } else if (src_is_quant) { + if (!src_traits || !src_traits->to_float) { + throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type))); + } + + src_traits->to_float(src, dst, (int)n_per_row); + } else { + throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type))); + } + }; + for (int64_t slice = 0; slice < ne2; ++slice) { std::mt19937 rng(std::hash{}(name) ^ 0xeabada55cafed00d ^ slice); int64_t current_sampled_rows = 0; @@ -1052,31 +1070,18 @@ static std::unordered_map target_bpw_type( } for (int64_t r = offset; r < nrows_total && current_sampled_rows < rows_sample_max; r += stride) { + const uint8_t * src_row = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz; if (src_type == GGML_TYPE_F32) { - const float * src_row = (const float *)tensor->data + slice * (n_per_row * nrows_total) + r * n_per_row; - f32_sample.insert(f32_sample.end(), src_row, src_row + n_per_row); - } else if (src_type == GGML_TYPE_F16) { - const auto * src_row = (const ggml_fp16_t *)((const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); - ggml_fp16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); - } else if (src_type == GGML_TYPE_BF16) { - const auto * src_row = (const ggml_bf16_t *)((const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz); - ggml_bf16_to_fp32_row(src_row, row_buffer.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); - } else if (src_is_quant) { - const uint8_t * qrow = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz; - if (!src_traits || !src_traits->to_float) { - throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type))); - } - src_traits->to_float(qrow, row_buffer.data(), (int)n_per_row); - f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); + auto src_f32 = (const float *)src_row; + f32_sample.insert(f32_sample.end(), src_f32, src_f32 + n_per_row); } else { - throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type))); + row_to_fp32(src_row, row_buffer.data()); + f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); } ++current_sampled_rows; } - + rows_sample[slice] = current_sampled_rows; }