Refactor side_data() and copy_or_broadcast()

This commit is contained in:
Ed Addario 2025-09-21 16:19:03 +01:00
parent 7386d4eadd
commit 08146fd67f
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 4 additions and 9 deletions

View File

@ -1088,14 +1088,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
} }
} }
auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) -> std::pair<const float*, size_t> { auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) {
if (!m) { return {nullptr, 0}; } if (!m) { return std::pair<const float*, size_t>{nullptr, 0}; }
const std::string key = remap_imatrix(tensor_name, mapped); const std::string key = remap_imatrix(tensor_name, mapped);
const auto it = m->find(key); const auto it = m->find(key);
if (it == m->end()) { return {nullptr, 0}; } return it == m->end() ? std::pair<const float*, size_t>{nullptr, 0} : std::pair<const float*, size_t>{ it->second.data(), it->second.size() };
return { it->second.data(), it->second.size() };
}; };
// Copy this row's side data (values and activations), or broadcasts to all slices // Copy this row's side data (values and activations), or broadcasts to all slices
@ -1105,9 +1103,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const size_t want = (size_t)ne2 * (size_t)n_per_row; const size_t want = (size_t)ne2 * (size_t)n_per_row;
if (src_sz == want) { if (src_sz == want) {
dst.resize(want); dst.assign(src, src + want);
std::memcpy(dst.data(), src, want * sizeof(float));
return; return;
} }
if (src_sz == (size_t)n_per_row) { if (src_sz == (size_t)n_per_row) {
@ -1115,7 +1111,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
for (int64_t s = 0; s < ne2; ++s) { for (int64_t s = 0; s < ne2; ++s) {
std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float)); std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
} }
return; return;
} }