Refactor copy_or_broadcast()

This commit is contained in:
Ed Addario 2025-09-21 13:42:07 +01:00
parent e8e2aed17a
commit bdefdb673c
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 12 additions and 6 deletions

View File

@ -1087,6 +1087,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) -> std::pair<const float*, size_t> {
if (!m) { return {nullptr, 0}; }
const std::string key = remap_imatrix(tensor_name, mapped);
const auto it = m->find(key);
if (it == m->end()) { return {nullptr, 0}; }
@ -1095,22 +1096,27 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
};
// Copy this row's side data (values and activations), or broadcasts to all slices
auto copy_or_broadcast = [&](const float *src, size_t src_sz, std::vector<float> &dst) {
const size_t want = (size_t)ne2 * (size_t)n_per_row;
auto copy_or_broadcast = [&](const float * src, size_t src_sz, std::vector<float> & dst) {
dst.clear();
if (!src || src_sz == 0) { return; }
const size_t want = (size_t)ne2 * (size_t)n_per_row;
if (src_sz == want) {
dst.resize(want);
std::memcpy(dst.data(), src, want * sizeof(float));
} else if (src_sz == (size_t)n_per_row) {
return;
}
if (src_sz == (size_t)n_per_row) {
dst.resize(want);
for (int64_t s = 0; s < ne2; ++s) {
std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
}
} else {
LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n",
func, name.c_str(), src_sz, (size_t)n_per_row, want);
return;
}
LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n", func, name.c_str(), src_sz, (size_t)n_per_row, want);
};
const auto [values_all, values_sz] = side_data(values_data, name);