perplexity: add proper batching (#19661)
This commit is contained in:
parent
cceb1b4e33
commit
d612901116
|
|
@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
int count = 0;
|
||||
double nll = 0.0;
|
||||
|
||||
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
|
||||
const int n_seq = std::max(1, n_batch / n_ctx);
|
||||
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int start = i * params.ppl_stride;
|
||||
|
|
@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
}
|
||||
|
||||
const int n_batch = params.n_batch;
|
||||
const int num_batches = (n_ctx + n_batch - 1)/n_batch;
|
||||
const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
|
||||
// Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
|
||||
const int n_seq_max = llama_n_seq_max(ctx);
|
||||
int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
|
||||
if (n_seq > n_seq_max) {
|
||||
LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
|
||||
__func__, n_seq, n_seq_max, n_seq_max);
|
||||
n_seq = n_seq_max;
|
||||
}
|
||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
|
||||
|
||||
llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
|
||||
|
||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||
std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||
|
|
@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
logits.reserve(size_t(n_ctx) * n_vocab);
|
||||
}
|
||||
|
||||
LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
|
||||
|
|
@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
auto kld_ptr = kld_values.data();
|
||||
auto p_diff_ptr = p_diff_values.data();
|
||||
|
||||
for (int i = 0; i < n_chunk; ++i) {
|
||||
const int first = n_ctx/2;
|
||||
|
||||
for (int i = 0; i < n_chunk; i += n_seq) {
|
||||
const int start = i * n_ctx;
|
||||
const int end = start + n_ctx;
|
||||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
const int n_seq_batch = std::min(n_seq, n_chunk - i);
|
||||
|
||||
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
|
||||
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
|
||||
return;
|
||||
}
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
|
||||
llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[batch_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[batch_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
int n_outputs = 0;
|
||||
|
||||
common_batch_clear(batch);
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
int seq_start = batch_start + seq*n_ctx;
|
||||
|
||||
// save original token and restore it after eval
|
||||
const auto token_org = tokens[seq_start];
|
||||
|
||||
// add BOS token for the first batch of each chunk
|
||||
if (add_bos && j == 0) {
|
||||
tokens[seq_start] = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
for (int k = 0; k < batch_size; ++k) {
|
||||
const int pos = j*n_batch + k;
|
||||
const bool need_logits = pos >= first;
|
||||
common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
|
||||
n_outputs += need_logits;
|
||||
}
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[seq_start] = token_org;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
LOG_ERR("%s : failed to decode\n", __func__);
|
||||
llama_batch_free(batch);
|
||||
return;
|
||||
}
|
||||
|
||||
// restore the original token in case it was set to BOS
|
||||
tokens[batch_start] = token_org;
|
||||
|
||||
if (num_batches > 1) {
|
||||
if (num_batches > 1 && n_outputs > 0) {
|
||||
const auto * batch_logits = llama_get_logits(ctx);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
|
||||
logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
|
||||
if (i == 0) {
|
||||
llama_synchronize(ctx);
|
||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
||||
int total_seconds = (int)(t_total * n_chunk);
|
||||
int total_seconds = (int)(t_total * n_chunk / n_seq);
|
||||
if (total_seconds >= 60*60) {
|
||||
LOG("%d hours ", total_seconds / (60*60));
|
||||
total_seconds = total_seconds % (60*60);
|
||||
}
|
||||
LOG("%.2f minutes\n", total_seconds / 60.0);
|
||||
LOG("\n");
|
||||
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
|
||||
}
|
||||
LOG("\n");
|
||||
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
|
||||
|
||||
const int first = n_ctx/2;
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||
process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
|
||||
p_diff_ptr += n_ctx - 1 - first;
|
||||
kld_ptr += n_ctx - 1 - first;
|
||||
// Read log probs for each sequence in the batch
|
||||
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
|
||||
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
|
||||
llama_batch_free(batch);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG("%4d", i+1);
|
||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
|
||||
|
||||
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||
const double ppl_val = exp(log_ppl.first);
|
||||
const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
|
||||
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
|
||||
process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
|
||||
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
|
||||
p_diff_ptr += n_ctx - 1 - first;
|
||||
kld_ptr += n_ctx - 1 - first;
|
||||
|
||||
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
|
||||
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
|
||||
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
|
||||
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
|
||||
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
|
||||
LOG("%4d", i + seq + 1);
|
||||
|
||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
|
||||
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||
const double ppl_val = exp(log_ppl.first);
|
||||
const double ppl_unc = ppl_val * log_ppl.second;
|
||||
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
|
||||
|
||||
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
|
||||
const double p_diff_rms_val = sqrt(p_diff_mse.first);
|
||||
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
|
||||
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
|
||||
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
|
||||
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
|
||||
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
|
||||
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
|
||||
|
||||
double p_top_val = 1.*kld.n_same_top/kld.count;
|
||||
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
|
||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
|
||||
|
||||
LOG("\n");
|
||||
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
|
||||
const double p_diff_rms_val = sqrt(p_diff_mse.first);
|
||||
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
|
||||
|
||||
double p_top_val = 1.*kld.n_same_top/kld.count;
|
||||
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
|
||||
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
|
||||
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
logits.clear();
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
LOG("\n");
|
||||
|
||||
if (kld.count < 100) return; // we do not wish to do statistics on so few values
|
||||
|
|
@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
|
||||
|
||||
if (ppl) {
|
||||
if (ppl || params.kl_divergence) {
|
||||
const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
|
||||
const int32_t n_kv = n_seq * n_ctx;
|
||||
|
||||
|
|
@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) {
|
|||
params.n_batch = std::min(params.n_batch, n_kv);
|
||||
} else {
|
||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
if (params.kl_divergence) {
|
||||
params.n_parallel = 1;
|
||||
} else {
|
||||
// ensure there's at least enough seq_ids for HellaSwag
|
||||
params.n_parallel = std::max(4, params.n_parallel);
|
||||
}
|
||||
// ensure there's at least enough seq_ids for HellaSwag
|
||||
params.n_parallel = std::max(4, params.n_parallel);
|
||||
}
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue