Refactor side_data() and copy_or_broadcast()
This commit is contained in:
parent
7386d4eadd
commit
08146fd67f
|
|
@ -1088,14 +1088,12 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) -> std::pair<const float*, size_t> {
|
auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) {
|
||||||
if (!m) { return {nullptr, 0}; }
|
if (!m) { return std::pair<const float*, size_t>{nullptr, 0}; }
|
||||||
|
|
||||||
const std::string key = remap_imatrix(tensor_name, mapped);
|
const std::string key = remap_imatrix(tensor_name, mapped);
|
||||||
const auto it = m->find(key);
|
const auto it = m->find(key);
|
||||||
if (it == m->end()) { return {nullptr, 0}; }
|
return it == m->end() ? std::pair<const float*, size_t>{nullptr, 0} : std::pair<const float*, size_t>{ it->second.data(), it->second.size() };
|
||||||
|
|
||||||
return { it->second.data(), it->second.size() };
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Copy this row's side data (values and activations), or broadcasts to all slices
|
// Copy this row's side data (values and activations), or broadcasts to all slices
|
||||||
|
|
@ -1105,9 +1103,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
|
|
||||||
const size_t want = (size_t)ne2 * (size_t)n_per_row;
|
const size_t want = (size_t)ne2 * (size_t)n_per_row;
|
||||||
if (src_sz == want) {
|
if (src_sz == want) {
|
||||||
dst.resize(want);
|
dst.assign(src, src + want);
|
||||||
std::memcpy(dst.data(), src, want * sizeof(float));
|
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (src_sz == (size_t)n_per_row) {
|
if (src_sz == (size_t)n_per_row) {
|
||||||
|
|
@ -1115,7 +1111,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
for (int64_t s = 0; s < ne2; ++s) {
|
for (int64_t s = 0; s < ne2; ++s) {
|
||||||
std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
|
std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue