perplexity: add proper batching (#19661)

This commit is contained in:
AesSedai 2026-02-16 08:44:44 -08:00 committed by GitHub
parent cceb1b4e33
commit d612901116
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 89 additions and 65 deletions

View File

@@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); const int n_seq = std::max(1, n_batch / n_ctx);
LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
for (int i = 0; i < n_chunk; ++i) { for (int i = 0; i < n_chunk; ++i) {
const int start = i * params.ppl_stride; const int start = i * params.ppl_stride;
@@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
} }
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int num_batches = (static_cast<int>(n_ctx) + n_batch - 1) / n_batch;
// Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports
const int n_seq_max = llama_n_seq_max(ctx);
int n_seq = std::max(1, n_batch / static_cast<int>(n_ctx));
if (n_seq > n_seq_max) {
LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n",
__func__, n_seq, n_seq_max, n_seq_max);
n_seq = n_seq_max;
}
const int nv = 2*((n_vocab + 1)/2) + 4; const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_vocab_get_add_bos(vocab); const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
llama_batch batch = llama_batch_init(std::min(n_batch, static_cast<int>(n_ctx)*n_seq), 0, 1);
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector<float> p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
logits.reserve(size_t(n_ctx) * n_vocab); logits.reserve(size_t(n_ctx) * n_vocab);
} }
LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) { auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) {
@@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
auto kld_ptr = kld_values.data(); auto kld_ptr = kld_values.data();
auto p_diff_ptr = p_diff_values.data(); auto p_diff_ptr = p_diff_values.data();
for (int i = 0; i < n_chunk; ++i) { const int first = n_ctx/2;
for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx; const int start = i * n_ctx;
const int end = start + n_ctx; const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now(); const int n_seq_batch = std::min(n_seq, n_chunk - i);
if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { const auto t_start = std::chrono::high_resolution_clock::now();
LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i);
return;
}
// clear the KV cache // clear the KV cache
llama_memory_clear(llama_get_memory(ctx), true); llama_memory_clear(llama_get_memory(ctx), true);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval int n_outputs = 0;
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch); common_batch_clear(batch);
for (int i = 0; i < batch_size; i++) { for (int seq = 0; seq < n_seq_batch; seq++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); int seq_start = batch_start + seq*n_ctx;
// save original token and restore it after eval
const auto token_org = tokens[seq_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_vocab_bos(vocab);
}
for (int k = 0; k < batch_size; ++k) {
const int pos = j*n_batch + k;
const bool need_logits = pos >= first;
common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits);
n_outputs += need_logits;
}
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
} }
if (llama_decode(ctx, batch)) { if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to decode\n", __func__);
llama_batch_free(batch); llama_batch_free(batch);
return; return;
} }
// restore the original token in case it was set to BOS if (num_batches > 1 && n_outputs > 0) {
tokens[batch_start] = token_org;
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx); const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab); logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab);
} }
} }
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) { if (i == 0) {
llama_synchronize(ctx);
const auto t_end = std::chrono::high_resolution_clock::now();
const float t_total = std::chrono::duration<float>(t_end - t_start).count(); const float t_total = std::chrono::duration<float>(t_end - t_start).count();
LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk); int total_seconds = (int)(t_total * n_chunk / n_seq);
if (total_seconds >= 60*60) { if (total_seconds >= 60*60) {
LOG("%d hours ", total_seconds / (60*60)); LOG("%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60); total_seconds = total_seconds % (60*60);
} }
LOG("%.2f minutes\n", total_seconds / 60.0); LOG("%.2f minutes\n", total_seconds / 60.0);
LOG("\n");
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
} }
LOG("\n");
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
const int first = n_ctx/2; // Read log probs for each sequence in the batch
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); for (int seq = 0; seq < n_seq_batch; seq++) {
process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) {
workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq);
p_diff_ptr += n_ctx - 1 - first; llama_batch_free(batch);
kld_ptr += n_ctx - 1 - first; return;
}
LOG("%4d", i+1); const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first,
const double ppl_val = exp(log_ppl.first); workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr);
const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) p_diff_ptr += n_ctx - 1 - first;
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); kld_ptr += n_ctx - 1 - first;
auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); LOG("%4d", i + seq + 1);
const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); const double ppl_val = exp(log_ppl.first);
const double ppl_unc = ppl_val * log_ppl.second;
LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc);
auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count);
const double p_diff_rms_val = sqrt(p_diff_mse.first); const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count);
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first;
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov);
LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc);
double p_top_val = 1.*kld.n_same_top/kld.count; auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second);
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
LOG("\n"); auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count);
const double p_diff_rms_val = sqrt(p_diff_mse.first);
const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second;
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc);
double p_top_val = 1.*kld.n_same_top/kld.count;
double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1));
LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc);
LOG("\n");
}
logits.clear(); logits.clear();
} }
llama_batch_free(batch);
LOG("\n"); LOG("\n");
if (kld.count < 100) return; // we do not wish to do statistics on so few values if (kld.count < 100) return; // we do not wish to do statistics on so few values
@@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) {
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
if (ppl) { if (ppl || params.kl_divergence) {
const int32_t n_seq = std::max(1, params.n_batch / n_ctx); const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
const int32_t n_kv = n_seq * n_ctx; const int32_t n_kv = n_seq * n_ctx;
@@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) {
params.n_batch = std::min(params.n_batch, n_kv); params.n_batch = std::min(params.n_batch, n_kv);
} else { } else {
params.n_batch = std::min(params.n_batch, params.n_ctx); params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.kl_divergence) { // ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = 1; params.n_parallel = std::max(4, params.n_parallel);
} else {
// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);
}
} }
if (params.ppl_stride > 0) { if (params.ppl_stride > 0) {