server: introduce self-speculative decoding

Sascha Rogmann 2025-12-29 20:46:32 +01:00
parent ced765be44
commit 5d67c20387
6 changed files with 105 additions and 11 deletions

View File

@@ -3216,6 +3216,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-self"}, "<0|1>",
"use self-speculation without a draft model (default: 0, no self speculation without draft model)",
[](common_params & params, int value) {
params.speculative.use_self = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(

View File

@@ -242,6 +242,7 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
int32_t use_self = 0; // use self-speculative decoding without a draft model (default: 0 = off)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
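This field is what the new --spec-self argument above controls: starting the server with, for example, llama-server -m model.gguf --spec-self 1 (binary name and model path are illustrative) turns self-drafting on by default, while the speculative.use_self request key parsed in the last file of this commit can still override it per completion.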

View File

@@ -359,3 +359,60 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}
llama_tokens common_speculative_gen_self_draft(const llama_tokens & tokens, llama_token sampled,
        size_t n_draft_min, size_t n_draft_max) {
    const size_t cur_len = tokens.size();

    // tokens we want the target model to verify
    // returned empty if there is no match
    llama_tokens draft_tokens;

    if (cur_len <= n_draft_min + n_draft_max + 1) {
        return draft_tokens;
    }

    // build the search pattern: the last (n_draft_min - 1) history tokens plus the freshly sampled token
    llama_tokens pattern;
    pattern.reserve(n_draft_min);
    for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
        pattern.push_back(tokens[j]);
    }
    pattern.push_back(sampled); // add the last sampled token to the pattern

    size_t match_pos = 0; // position 0 is ignored, so 0 == no match

    // search backwards, but skip the current occurrence (we are currently there)
    for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
        bool match = true;
        for (size_t k = 0; k < pattern.size(); ++k) {
            if (tokens[j + k] != pattern[k]) {
                match = false;
                break;
            }
        }
        if (match) {
            match_pos = j;
            break;
        }
    }

    if (match_pos == 0) {
        return draft_tokens;
    }

    const size_t copy_max = std::min(
        n_draft_max,
        cur_len - (match_pos + n_draft_min)
    );
    if (copy_max < n_draft_min) {
        return draft_tokens;
    }

    LOG_DBG("%s: #tokens = %ld: found matching pattern at pos %ld, length %ld, draft length %ld\n",
            __func__, (int64_t) cur_len,
            (int64_t) match_pos, (int64_t) pattern.size(), (int64_t) copy_max);

    // the tokens that followed the earlier occurrence become the draft
    draft_tokens.reserve(copy_max);
    for (size_t j = 0; j < copy_max; ++j) {
        draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
    }
    return draft_tokens;
}
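To make the matching behaviour concrete, here is a small illustrative program (not part of the patch). It assumes the common/common.h and common/speculative.h headers and linking against the common library, and it uses small integers in place of real token ids, since llama_tokens is simply a vector of llama_token (int32_t) values.

#include "common.h"
#include "speculative.h"

#include <cstdio>

int main() {
    // history indices:                0  1  2  3  4  5  6  7  8  9
    const llama_tokens history    = { 9, 1, 2, 3, 4, 5, 6, 7, 1, 2 };
    const llama_token  sampled    = 3; // with n_draft_min = 3 the pattern becomes {1, 2, 3}

    // {1, 2, 3} also occurs at position 1, so the tokens that followed it
    // there, {4, 5, 6, 7}, are returned as the draft (capped at n_draft_max)
    const llama_tokens draft = common_speculative_gen_self_draft(history, sampled, /*n_draft_min =*/ 3, /*n_draft_max =*/ 4);

    for (const llama_token t : draft) {
        printf("%d ", t); // prints: 4 5 6 7
    }
    printf("\n");
    return 0;
}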

View File

@@ -33,3 +33,19 @@ llama_tokens common_speculative_gen_draft(
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
/**
* Generate a speculative draft from the model's own token history, without a draft model.
* Searches for an earlier occurrence of the most recent tokens and returns the tokens that followed it as the draft.
*
* @param tokens Token history to search in
* @param sampled Last sampled token
* @param n_draft_min Minimum number of draft tokens required
* @param n_draft_max Maximum number of draft tokens to generate
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_speculative_gen_self_draft(
const llama_tokens & tokens,
llama_token sampled,
size_t n_draft_min,
size_t n_draft_max);
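For reference, a hedged call-site sketch showing how the two bounds can be derived from common_params_speculative, mirroring the server change later in this commit; the helper name try_self_draft and the history/last_sampled parameters are illustrative and not part of the patch.

#include "common.h"
#include "speculative.h"

#include <algorithm>

// illustrative helper: returns an empty draft when self-speculation is disabled
// or when no repeated pattern is found in the history
llama_tokens try_self_draft(const common_params_speculative & spec,
        const llama_tokens & history, llama_token last_sampled) {
    if (!spec.use_self) {
        return {};
    }
    const size_t n_draft_min = std::max(5, spec.n_min); // at least a 5-token search pattern
    const size_t n_draft_max = std::max(0, spec.n_max);
    return common_speculative_gen_self_draft(history, last_sampled, n_draft_min, n_draft_max);
}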

View File

@@ -264,7 +264,7 @@ struct server_slot {
}
int get_n_draft_max() const {
-if (!can_speculate()) {
+if (!can_speculate() && !task->params.speculative.use_self) {
return 0;
}
@@ -1153,7 +1153,7 @@ private:
// initialize draft batch
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
-if (slot.ctx_dft) {
+if (slot.ctx_dft || task.params.speculative.use_self) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
@@ -1974,12 +1974,23 @@ private:
GGML_ABORT("not supported by multimodal");
}
-struct common_speculative_params params_spec;
-params_spec.n_draft = n_draft_max;
-params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
-params_spec.p_min = slot.task->params.speculative.p_min;
-const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
-llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+llama_tokens draft = {};
+if (slot.task->params.speculative.use_self) {
+    // try a self-speculative draft first: the search pattern uses at least 5 tokens of history
+    const int n_draft_min = std::max(5, slot.task->params.speculative.n_min);
+    const llama_tokens & tokens = slot.prompt.tokens.get_text_tokens();
+    llama_token id = slot.sampled;
+    draft = common_speculative_gen_self_draft(tokens, id, n_draft_min, n_draft_max);
+}
+if (draft.empty() && slot.can_speculate()) {
+    // fall back to the regular draft model
+    struct common_speculative_params params_spec;
+    params_spec.n_draft = n_draft_max;
+    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
+    params_spec.p_min = slot.task->params.speculative.p_min;
+    const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
+    draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
+}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
@@ -2748,6 +2759,7 @@ private:
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
}
}
SRV_DBG("%s", "run slots completed\n");

View File

@@ -206,9 +206,10 @@ task_params server_task::params_from_json_cmpl(
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.use_self = json_value(data, "speculative.use_self", defaults.speculative.use_self);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 0);
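To illustrate the per-request override parsed above, here is a hedged client-side sketch that assembles a /completion request body with nlohmann::json, the library the server itself uses for parsing; only the speculative.use_self key is introduced by this commit, the neighbouring keys already exist, and the field values are illustrative.

#include <nlohmann/json.hpp>

#include <cstdio>

int main() {
    // body for POST /completion; "speculative.use_self" enables the self-draft path
    // for this request only, otherwise the server-wide --spec-self default applies
    const nlohmann::ordered_json body = {
        { "prompt",               "Once upon a time" },
        { "n_predict",            128 },
        { "speculative.n_min",    5 },
        { "speculative.n_max",    16 },
        { "speculative.use_self", 1 }
    };
    printf("%s\n", body.dump(2).c_str());
    return 0;
}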