common: simplify speculative sampling to greedy-only for performance

Removes heavy penalty checks (repetition, frequency, presence, DRY) from
`common_sampler_sample_speculative`.

The specialized speculative sampler now uses a pure ArgMax (Greedy) approach.
This significantly reduces CPU overhead during the drafting phase, which
improves overall tokens per second.
This commit is contained in:
samuel 2025-12-19 21:57:15 -03:00 committed by Aaron Lee
parent a3e29da02a
commit a8dc54672c
1 changed files with 3 additions and 16 deletions

View File

@ -670,26 +670,13 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
/**
* Specialized sampling for speculative drafting.
*
* Prioritizes performance by using a direct ArgMax loop (Greedy) when no
* penalties (repetition, frequency, presence, DRY) are configured.
* Falls back to the full sampler chain if penalties are active to prevent
* generative loops or adhere to constraints.
* Prioritizes performance by using a direct ArgMax loop (Greedy).
* Penalties and complex sampling logic are bypassed to minimize
* drafting latency.
*/
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
const auto & params = gsmpl->params;
bool use_heavy_sampler =
(params.penalty_last_n > 0 && (
params.penalty_repeat != 1.0f ||
params.penalty_freq != 0.0f ||
params.penalty_present != 0.0f
)) ||
(params.dry_allowed_length > 0 && params.dry_multiplier != 0.0f);
if (use_heavy_sampler) {
return common_sampler_sample(gsmpl, ctx, idx, false);
}
float * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx)));