Refactor helper lambdas

This commit is contained in:
Ed Addario 2025-09-21 16:04:13 +01:00
parent b433fd9547
commit b6c008fd8a
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 6 additions and 11 deletions

View File

@@ -665,28 +665,23 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t { auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
const int64_t n_per_row = t->ne[0]; const int64_t n_per_row = t->ne[0];
const size_t row_sz = ggml_row_size(typ, n_per_row); const size_t row_sz = ggml_row_size(typ, n_per_row);
const int64_t nrows = ggml_nrows(t); return (size_t)ggml_nrows(t) * row_sz;
return (size_t)nrows * row_sz;
}; };
auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double { auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
const int64_t nelem = ggml_nelements(t);
const size_t bytes = tensor_bytes(t, typ); const size_t bytes = tensor_bytes(t, typ);
return (double)bytes * 8.0 / (double)nelem; return (double)bytes * 8.0 / (double)ggml_nelements(t);
}; };
auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool { auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool {
const int64_t n_per_row = t->ne[0];
const int64_t blck = ggml_blck_size(typ); const int64_t blck = ggml_blck_size(typ);
if (blck <= 1) { return true; } return blck <= 1 || (t->ne[0] % blck) == 0;
return n_per_row % blck == 0;
}; };
auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type { auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
if (is_compatible(t, typ)) { return typ; } if (is_compatible(t, typ)) return typ;
ggml_type fb = fallback_type(typ); ggml_type fb = fallback_type(typ);
if (is_compatible(t, fb)) { return fb; } return is_compatible(t, fb) ? fb : GGML_TYPE_F16;
return GGML_TYPE_F16;
}; };
auto name_tn = LLM_TN(model.arch); auto name_tn = LLM_TN(model.arch);
@@ -1080,7 +1075,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
++current_sampled_rows; ++current_sampled_rows;
} }
rows_sample[slice] = current_sampled_rows; rows_sample[slice] = current_sampled_rows;
} }