From d612901116ab2066c7923372d4827032ff296bc4 Mon Sep 17 00:00:00 2001
From: AesSedai <7980540+AesSedai@users.noreply.github.com>
Date: Mon, 16 Feb 2026 08:44:44 -0800
Subject: [PATCH] perplexity: add proper batching (#19661)

---
 tools/perplexity/perplexity.cpp | 154 ++++++++++++++++++--------------
 1 file changed, 89 insertions(+), 65 deletions(-)

diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp
index 1ead9c871e..433b747f0d 100644
--- a/tools/perplexity/perplexity.cpp
+++ b/tools/perplexity/perplexity.cpp
@@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
     int count = 0;
     double nll = 0.0;
 
-    LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    const int n_seq = std::max(1, n_batch / n_ctx);
+    LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.ppl_stride;
@@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
     }
 
     const int n_batch = params.n_batch;
-    const int num_batches = (n_ctx + n_batch - 1)/n_batch;
+    const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
+    // Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
+    const int n_seq_max = llama_n_seq_max(ctx);
+    int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
+    if (n_seq > n_seq_max) {
+        LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
+                __func__, n_seq, n_seq_max, n_seq_max);
+        n_seq = n_seq_max;
+    }
     const int nv = 2*((n_vocab + 1)/2) + 4;
     const bool add_bos = llama_vocab_get_add_bos(vocab);
     GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
 
+    llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
+
     std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
     std::vector<float>    kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
     std::vector<float>    p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         logits.reserve(size_t(n_ctx) * n_vocab);
     }
 
+    LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
    auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
@@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
     auto kld_ptr    = kld_values.data();
     auto p_diff_ptr = p_diff_values.data();
 
-    for (int i = 0; i < n_chunk; ++i) {
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
-        const auto t_start = std::chrono::high_resolution_clock::now();
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
 
-        if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
-            LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
-            return;
-        }
+        const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
         llama_memory_clear(llama_get_memory(ctx), true);
 
-        llama_batch batch = llama_batch_init(n_batch, 0, 1);
-
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_vocab_bos(vocab);
-            }
+            int n_outputs = 0;
 
             common_batch_clear(batch);
-            for (int i = 0; i < batch_size; i++) {
-                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
+
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_vocab_bos(vocab);
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int pos = j*n_batch + k;
+                    const bool need_logits = pos >= first;
+                    common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
+                    n_outputs += need_logits;
+                }
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
             if (llama_decode(ctx, batch)) {
-                LOG_ERR("%s : failed to eval\n", __func__);
+                LOG_ERR("%s : failed to decode\n", __func__);
                 llama_batch_free(batch);
                 return;
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
-            if (num_batches > 1) {
+            if (num_batches > 1 && n_outputs > 0) {
                 const auto * batch_logits = llama_get_logits(ctx);
-                logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab);
+                logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
             }
         }
 
-        llama_batch_free(batch);
-
-        const auto t_end = std::chrono::high_resolution_clock::now();
-
         if (i == 0) {
+            llama_synchronize(ctx);
+            const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total * n_chunk / n_seq);
             if (total_seconds >= 60*60) {
                 LOG("%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             LOG("%.2f minutes\n", total_seconds / 60.0);
+            LOG("\n");
+            LOG("chunk        PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");
         }
-        LOG("\n");
-        LOG("chunk        PPL               ln(PPL(Q)/PPL(base))          KL Divergence              Δp RMS            Same top p\n");
 
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
-        p_diff_ptr += n_ctx - 1 - first;
-        kld_ptr    += n_ctx - 1 - first;
+        // Read log probs for each sequence in the batch
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
+                LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
+                llama_batch_free(batch);
+                return;
+            }
 
-        LOG("%4d", i+1);
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
 
-        auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
-        const double ppl_val = exp(log_ppl.first);
-        const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 )
-        LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
+            process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
+                    workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
+            p_diff_ptr += n_ctx - 1 - first;
+            kld_ptr    += n_ctx - 1 - first;
 
-        auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
-        const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
-        const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
-        const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
-        LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
+            LOG("%4d", i + seq + 1);
 
-        auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
-        LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
+            auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
+            const double ppl_val = exp(log_ppl.first);
+            const double ppl_unc = ppl_val * log_ppl.second;
+            LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
 
-        auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
-        const double p_diff_rms_val = sqrt(p_diff_mse.first);
-        const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
-        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+            auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
+            const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
+            const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
+            const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
+            LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
 
-        double p_top_val = 1.*kld.n_same_top/kld.count;
-        double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
-        LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
+            auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
+            LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
 
-        LOG("\n");
+            auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
+            const double p_diff_rms_val = sqrt(p_diff_mse.first);
+            const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
+            LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
+
+            double p_top_val = 1.*kld.n_same_top/kld.count;
+            double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
+            LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
+
+            LOG("\n");
+        }
 
         logits.clear();
     }
+
+    llama_batch_free(batch);
     LOG("\n");
 
     if (kld.count < 100) return; // we do not wish to do statistics on so few values
@@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
 
     const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
 
-    if (ppl) {
+    if (ppl || params.kl_divergence) {
         const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
         const int32_t n_kv = n_seq * n_ctx;
 
@@ -2006,12 +2034,8 @@
         params.n_batch = std::min(params.n_batch, n_kv);
     } else {
         params.n_batch = std::min(params.n_batch, params.n_ctx);
-        if (params.kl_divergence) {
-            params.n_parallel = 1;
-        } else {
-            // ensure there's at least enough seq_ids for HellaSwag
-            params.n_parallel = std::max(4, params.n_parallel);
-        }
+        // ensure there's at least enough seq_ids for HellaSwag
+        params.n_parallel = std::max(4, params.n_parallel);
     }
 
     if (params.ppl_stride > 0) {