Compute rows based on tensor shape and slice count
This commit is contained in:
parent
e49e241d37
commit
b3b8a111a5
|
|
@ -650,9 +650,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
GGML_TYPE_IQ3_XXS,
|
||||
GGML_TYPE_Q3_K,
|
||||
GGML_TYPE_IQ4_XS,
|
||||
GGML_TYPE_Q4_1,
|
||||
GGML_TYPE_Q4_K,
|
||||
GGML_TYPE_Q5_1,
|
||||
GGML_TYPE_Q5_K,
|
||||
GGML_TYPE_Q6_K,
|
||||
GGML_TYPE_Q8_0
|
||||
|
|
@ -961,10 +959,24 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
ml.load_data_for(tensor);
|
||||
|
||||
// Dequantize sampled rows into f32_sample
|
||||
const int rows_sample_per_expert = activations_data ? 512 : 256;
|
||||
const int64_t n_per_row = tensor->ne[0];
|
||||
const int64_t nrows_total = tensor->ne[1];
|
||||
const int64_t ne2 = tensor->ne[2] > 0 ? tensor->ne[2] : 1;
|
||||
|
||||
// Compute rows based on tensor shape and slice count
|
||||
auto sample_rows = [](const int64_t n, const int64_t rows, const int64_t n2, const bool has_acts) -> int64_t {
|
||||
const double tensor_budget = has_acts ? 1 * 1024 * 1024 : 0.5 * 1024 * 1024;
|
||||
const double scale_rows = std::clamp(std::sqrt(std::max(1.0, (double)rows) / 4096.0), 0.5, 2.0); // favour more rows for large nrt
|
||||
const double slice_budget = tensor_budget * scale_rows / std::max<int64_t>(1, n2);
|
||||
const int64_t min_rows = has_acts ? 128 : 64;
|
||||
const int64_t max_rows = 4096;
|
||||
int64_t total_rows = std::llround(slice_budget / std::max<int64_t>(1, n));
|
||||
total_rows = std::max<int64_t>(min_rows, std::min<int64_t>(total_rows, std::min<int64_t>(rows, max_rows)));
|
||||
if (rows <= min_rows * 2) { total_rows = rows; } // use all rows for small tensors
|
||||
return total_rows;
|
||||
};
|
||||
|
||||
const int64_t rows_sample_per_expert = sample_rows(n_per_row, nrows_total, ne2, activations_data != nullptr);
|
||||
std::vector<float> f32_sample;
|
||||
f32_sample.reserve((size_t)ne2 * (size_t)std::min<int64_t>(nrows_total, rows_sample_per_expert) * (size_t)n_per_row);
|
||||
std::vector<int64_t> rows_sample(ne2, 0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue