Refactor row sampling

Ed Addario 2025-09-21 16:18:26 +01:00
parent b6c008fd8a
commit 7386d4eadd
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 48 additions and 39 deletions

@@ -1019,64 +1019,73 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
 ml.load_data_for(tensor);
 // Dequantize sampled rows into f32_sample
-const int rows_sample_per_expert = activations_data ? 512 : 256;
 const int64_t n_per_row = tensor->ne[0];
 const int64_t nrows_total = tensor->ne[1];
 const int64_t ne2 = tensor->ne[2] > 0 ? tensor->ne[2] : 1;
+const int rows_sample_per_expert = activations_data ? 512 : 256;
 std::vector<float> f32_sample;
 f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, rows_sample_per_expert) * (size_t)n_per_row);
 std::vector<int64_t> rows_sample(ne2, 0);
-const int64_t rows_sample_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, rows_sample_per_expert));
-const int64_t stride = std::max<int64_t>(1, nrows_total / rows_sample_max);
 const ggml_type src_type = tensor->type;
 const ggml_type_traits * src_traits = ggml_get_type_traits(src_type);
 const bool src_is_quant = ggml_is_quantized(src_type);
 const size_t src_row_sz = ggml_row_size(src_type, n_per_row);
-std::vector<float> row_buffer(n_per_row);
 // Convert a single row to fp32
 auto row_to_fp32 = [&](const uint8_t * src, float * dst) {
-    if (src_type == GGML_TYPE_F32) {
+    const ggml_type t = src_type;
+    if (t == GGML_TYPE_F32) {
         std::memcpy(dst, src, sizeof(float) * (size_t)n_per_row);
-    } else if (src_type == GGML_TYPE_F16) {
-        ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
-    } else if (src_type == GGML_TYPE_BF16) {
-        ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
-    } else if (src_is_quant) {
-        if (!src_traits || !src_traits->to_float) {
-            throw std::runtime_error(format("cannot dequantize type %s for sampling", ggml_type_name(src_type)));
-        }
-        src_traits->to_float(src, dst, (int)n_per_row);
-    } else {
-        throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(src_type)));
-    }
+        return;
+    }
+    if (t == GGML_TYPE_F16) {
+        ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
+        return;
+    }
+    if (t == GGML_TYPE_BF16) {
+        ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
+        return;
+    }
+    if (src_is_quant) {
+        GGML_ASSERT(src_traits && src_traits->to_float);
+        src_traits->to_float(src, dst, (int) n_per_row);
+        return;
+    }
+    throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(t)));
 };
-for (int64_t slice = 0; slice < ne2; ++slice) {
-    std::mt19937 rng(std::hash<std::string>{}(name) ^ 0xeabada55cafed00d ^ slice);
-    int64_t current_sampled_rows = 0;
-    int64_t offset = 0;
-    if (stride > 1) {
-        std::uniform_int_distribution<int64_t> dist(0, stride - 1);
-        offset = dist(rng);
-    }
-    for (int64_t r = offset; r < nrows_total && current_sampled_rows < rows_sample_max; r += stride) {
-        const uint8_t * src_row = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
-        if (src_type == GGML_TYPE_F32) {
-            auto src_f32 = (const float *)src_row;
-            f32_sample.insert(f32_sample.end(), src_f32, src_f32 + n_per_row);
-        } else {
-            row_to_fp32(src_row, row_buffer.data());
-            f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
-        }
-        ++current_sampled_rows;
-    }
-    rows_sample[slice] = current_sampled_rows;
-}
+// Sample rows randomly per slice
+{
+    f32_sample.clear();
+    std::vector<float> row_buffer(n_per_row);
+    for (int64_t slice = 0; slice < ne2; ++slice) {
+        std::mt19937 rng(std::hash<std::string>{}(name) ^ 0xeabada55cafed00d ^ slice);
+        const int64_t rows_sample_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, rows_sample_per_expert));
+        const int64_t stride = std::max<int64_t>(1, nrows_total / rows_sample_max);
+        int64_t offset = 0;
+        if (stride > 1) {
+            std::uniform_int_distribution<int64_t> dist(0, stride - 1);
+            offset = dist(rng);
+        }
+        int64_t current = 0;
+        for (int64_t r = offset; r < nrows_total && current < rows_sample_max; r += stride) {
+            const uint8_t * src_row = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
+            if (src_type == GGML_TYPE_F32) {
+                auto src_f32 = (const float *)src_row;
+                f32_sample.insert(f32_sample.end(), src_f32, src_f32 + n_per_row);
+            } else {
+                row_to_fp32(src_row, row_buffer.data());
+                f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
+            }
+            ++current;
+        }
+        rows_sample[slice] = current;
+    }
+}
 auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) -> std::pair<const float*, size_t> {
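
For readers following the refactor, here is a minimal standalone sketch of the sampling scheme visible in the hunk above: each expert slice gets its own RNG seeded from the tensor name and slice index, rows are visited at a fixed stride of nrows_total / rows_sample_max, and a random offset inside the first stride decides where the strided walk starts. The helper name sample_row_indices and its index-vector return type are illustrative assumptions only; the commit itself dequantizes each selected row straight into f32_sample rather than collecting indices.

// Illustrative sketch, not part of llama.cpp: mirrors the strided,
// per-slice row sampling above, but returns row indices instead of
// dequantized row data.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <random>
#include <string>
#include <vector>

// Pick up to max_rows indices out of nrows_total rows: walk the rows at a
// fixed stride, starting at a random offset inside the first stride. The
// RNG is seeded from the tensor name and slice index, so the same tensor
// always yields the same sample.
std::vector<int64_t> sample_row_indices(const std::string & name, int64_t slice, int64_t nrows_total, int64_t max_rows) {
    std::mt19937 rng(std::hash<std::string>{}(name) ^ 0xeabada55cafed00dULL ^ (uint64_t) slice);
    const int64_t rows_max = std::max<int64_t>(1, std::min<int64_t>(nrows_total, max_rows));
    const int64_t stride   = std::max<int64_t>(1, nrows_total / rows_max);
    int64_t offset = 0;
    if (stride > 1) {
        std::uniform_int_distribution<int64_t> dist(0, stride - 1);
        offset = dist(rng);
    }
    std::vector<int64_t> rows;
    rows.reserve((size_t) rows_max);
    for (int64_t r = offset; r < nrows_total && (int64_t) rows.size() < rows_max; r += stride) {
        rows.push_back(r);
    }
    return rows;
}

For example, with nrows_total = 4096 and max_rows = 512 the stride is 8, the offset is drawn from [0, 7], and exactly 512 evenly spaced rows are selected per slice.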