This commit is contained in:
Sascha Rogmann 2026-01-02 22:43:02 +01:00 committed by GitHub
commit cfaba1d7f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 112 additions and 11 deletions

View File

@ -3216,6 +3216,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.replacements.push_back({ tgt, dft });
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--spec-self"}, "<0|1>",
"use self-speculation without a draft model (default: 0, no self speculation without draft model)",
[](common_params & params, int value) {
params.speculative.use_self = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
string_format(

View File

@ -242,6 +242,7 @@ struct common_params_speculative {
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
int32_t use_self = 0; // use self-speculative decoding without draft model (default: 0 = off)
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;

View File

@ -187,6 +187,18 @@ llama_tokens common_speculative_gen_draft(
struct common_speculative_params params,
const llama_tokens & prompt_tgt_main_model, // specified in target model vocab
llama_token id_last) {
if (params.self_mode == 1) {
// Look in the current context for a n-gram and return the following tokens as the draft.
llama_tokens draft_self = common_speculative_gen_self_draft(prompt_tgt_main_model, id_last,
params.self_ngram_size, params.n_draft);
if (!draft_self.empty()) {
return draft_self;
}
}
if (spec == nullptr) {
return {};
}
auto & batch = spec->batch;
auto & ctx_tgt = spec->ctx_tgt;
auto & ctx_dft = spec->ctx_dft;
@ -359,3 +371,60 @@ llama_tokens common_speculative_gen_draft(
}
return result;
}
llama_tokens common_speculative_gen_self_draft(const llama_tokens & tokens, llama_token sampled,
size_t n_draft_min, size_t n_draft_max) {
const size_t cur_len = tokens.size();
// vector for tokens we want to verify.
// return empty vector if there is no match.
llama_tokens draft_tokens;
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
return draft_tokens;
}
// pattern search
llama_tokens pattern;
pattern.reserve(n_draft_min);
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
pattern.push_back(tokens[j]);
}
pattern.push_back(sampled); // add the last token to the pattern
size_t match_pos = 0; // we ignore position 0, position 0 == no match
// search backwards, but skip the current match (we are currently there)
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
bool match = true;
for (size_t k = 0; k < pattern.size(); ++k) {
if (tokens[j + k] != pattern[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
if (match_pos == 0) {
return draft_tokens;
}
const size_t copy_max = std::min(
n_draft_max,
cur_len - (match_pos + n_draft_min)
);
if (copy_max < n_draft_min) {
return draft_tokens;
}
LOG_DBG("%s: #tokens = %ld: found matching pattern at pos %ld, length %ld, draft length %ld\n",
__func__, (int64_t) cur_len,
(int64_t) match_pos, (int64_t) pattern.size(), copy_max);
draft_tokens.reserve(copy_max);
for (size_t j = 0; j < copy_max; ++j) {
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
}
return draft_tokens;
}

View File

@ -6,10 +6,13 @@
struct common_speculative;
struct common_speculative_params {
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
int n_draft = 16; // max drafted tokens
int n_reuse = 256;
float p_min = 0.75f; // min probability required to accept a token in the draft
float p_min = 0.75f; // min probability required to accept a token in the draft
int self_mode = 0; // 0: off, 1: self speculative lookup
int self_ngram_size = 12; // length of pattern to search for in self mode
};
struct common_speculative * common_speculative_init(
@ -33,3 +36,19 @@ llama_tokens common_speculative_gen_draft(
struct common_speculative_params params,
const llama_tokens & prompt,
llama_token id_last);
/**
* Perform speculative generation using the model's own token history.
* Searches for a matching pattern in the token history and returns draft tokens.
*
* @param tokens Token history to search in
* @param sampled Last sampled token
* @param n_draft_min Minimum number of draft tokens required
* @param n_draft_max Maximum number of draft tokens to generate
* @return Vector of draft tokens, empty if no matching pattern is found
*/
llama_tokens common_speculative_gen_self_draft(
const llama_tokens & tokens,
llama_token sampled,
size_t n_draft_min,
size_t n_draft_max);

View File

@ -251,8 +251,9 @@ struct server_slot {
return state != SLOT_STATE_IDLE;
}
// Checks if a draft model is active or self-speculation using context-tokens
bool can_speculate() const {
return ctx_dft;
return ctx_dft || task->params.speculative.use_self;
}
void add_token(const completion_token_output & token) {
@ -1153,7 +1154,7 @@ private:
// initialize draft batch
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
if (slot.ctx_dft) {
if (slot.can_speculate()) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1);
@ -1975,9 +1976,11 @@ private:
}
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = slot.ctx_dft ? (llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max) : 0;
params_spec.p_min = slot.task->params.speculative.p_min;
params_spec.self_mode = slot.task->params.speculative.use_self;
params_spec.self_ngram_size = std::max(5, slot.task->params.speculative.n_min);
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
@ -2748,6 +2751,7 @@ private:
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens());
}
}
SRV_DBG("%s", "run slots completed\n");

View File

@ -206,9 +206,10 @@ task_params server_task::params_from_json_cmpl(
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.use_self = json_value(data, "speculative.use_self", defaults.speculative.use_self);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 0);