This commit is contained in:
Sascha Rogmann 2026-04-04 10:15:44 +03:00 committed by GitHub
commit 56919a38e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 695 additions and 126 deletions

View File

@ -324,6 +324,7 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
std::shared_ptr<common_ngram_mod> ngram_mod;

View File

@ -208,7 +208,7 @@ void common_ngram_map_begin(
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
}
map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
map.idx_last_check = size_begin;
map.size_last_begin = size_begin;
}
@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}
if (map.idx_last_check > cur_len) {
if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
map.last_draft_created = false;
map.last_draft_created = true;
map.last_draft_key_idx = key_offset;
map.last_draft_value_idx = 0; // value 0 is used for simple mode
return;
@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
// update the value statistics
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
n_accepted, curr_value.n_accepted);
curr_value.n_accepted = n_accepted;
}

View File

@ -13,6 +13,7 @@
#include <cstring>
#include <iomanip>
#include <map>
#include <cinttypes>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@ -144,10 +145,28 @@ struct common_speculative_state {
virtual void accept(uint16_t n_accepted) = 0;
};
struct common_speculative_checkpoint {
llama_pos pos_min = 0;
llama_pos pos_max = 0;
int64_t n_tokens = 0;
std::vector<uint8_t> data;
size_t size() const {
return data.size();
}
size_t ckpt_size = 0;
};
struct common_speculative_state_draft : public common_speculative_state {
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
llama_context * ctx_dft;
struct common_speculative_checkpoint ckpt;
bool use_checkpoint;
common_sampler * smpl;
llama_batch batch;
@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state {
enum common_speculative_type type,
llama_context * ctx_tgt,
llama_context * ctx_dft,
const std::vector<std::pair<std::string, std::string>> & replacements)
const std::vector<std::pair<std::string, std::string>> & replacements,
bool use_checkpoint)
: common_speculative_state(type)
, ctx_tgt(ctx_tgt)
, ctx_dft(ctx_dft)
, use_checkpoint(use_checkpoint)
{
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
smpl = nullptr;
@ -218,7 +239,47 @@ struct common_speculative_state_draft : public common_speculative_state {
}
void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt);
if (use_checkpoint && ckpt.size() > 0) {
// delete checkpoint
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
ckpt.pos_min = 0;
ckpt.pos_max = 0;
ckpt.n_tokens = 0;
ckpt.ckpt_size = 0;
ckpt.data.clear();
}
}
size_t draft_init_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
int slot_id = 0;
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
ckpt.data.resize(checkpoint_size);
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
return n;
}
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
int slot_id = 0;
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_dft,
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
}
return n;
}
void draft(
@ -236,8 +297,8 @@ struct common_speculative_state_draft : public common_speculative_state {
auto * mem_dft = llama_get_memory(ctx_dft);
int reuse_i = 0;
int reuse_n = 0;
int reuse_i = 0; // index of part to be reused in prompt_dft
int reuse_n = 0; // length of part to be reused in prompt_dft
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
@ -287,18 +348,26 @@ struct common_speculative_state_draft : public common_speculative_state {
}
}
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
__func__, reuse_i, reuse_n);
reuse_i = 0;
reuse_n = 0;
}
result.clear();
result.reserve(params.n_max);
if (reuse_n == 0) {
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
llama_memory_clear(mem_dft, false);
prompt_dft.clear();
} else {
// this happens when a previous draft has been discarded (for example, due to being too small), but the
// target model agreed with it. in this case, we simply pass back the previous results to save compute
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
result.push_back(prompt_dft[i]);
@ -310,19 +379,50 @@ struct common_speculative_state_draft : public common_speculative_state {
return;
}
bool do_restore = false;
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
// This can happen after a partial acceptance (speculative decoding with checkpoints)
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
__func__, prompt_dft.size(), prompt_cur.size());
prompt_dft.resize(prompt_cur.size());
do_restore = true;
}
if (reuse_i > 0) {
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
}
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
}
if (reuse_n < (int) prompt_dft.size()) {
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
if (reuse_n < (int) prompt_dft.size() || do_restore) {
if (use_checkpoint) {
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu, reuse_n=%d, prompt_dft.size=%zu\n",
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
}
draft_restore_checkpoint(ckpt.ckpt_size);
reuse_n = ckpt.n_tokens;
prompt_dft.resize(reuse_n);
needs_ckpt = false;
} else {
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
if (!is_removed) {
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
__func__, reuse_n, prompt_dft.size());
}
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
}
}
}
if (needs_ckpt && use_checkpoint) {
ckpt.ckpt_size = draft_init_checkpoint(prompt_dft.size(), batch.n_tokens);
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
@ -337,7 +437,11 @@ struct common_speculative_state_draft : public common_speculative_state {
if (batch.n_tokens > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx_dft, batch);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
__func__, ret, prompt_cur.size());
}
}
const llama_pos n_past = prompt_dft.size();
@ -351,7 +455,11 @@ struct common_speculative_state_draft : public common_speculative_state {
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
llama_decode(ctx_dft, batch);
int ret = llama_decode(ctx_dft, batch);
if (ret != 0 && ret != 1) {
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, ret, prompt_cur.size(), prompt_dft.size());
}
common_sampler_reset(smpl);
@ -387,7 +495,11 @@ struct common_speculative_state_draft : public common_speculative_state {
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch);
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
}
prompt_dft.push_back(id);
}
@ -798,13 +910,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second;
}
bool common_speculative_is_compat(llama_context * ctx_tgt) {
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
auto * mem = llama_get_memory(ctx_tgt);
if (mem == nullptr) {
return false;
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
}
bool res = true;
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
llama_memory_clear(mem, true);
@ -816,14 +928,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
res = false;
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
goto done;
}
@ -909,9 +1021,10 @@ common_speculative * common_speculative_init(
break;
case COMMON_SPECULATIVE_TYPE_DRAFT: {
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements
/* .ctx_tgt = */ ctx_tgt,
/* .ctx_dft = */ ctx_dft,
/* .replacements = */ params.replacements,
/* .use_checkpoint= */ params.use_checkpoints
));
break;
}
@ -1072,3 +1185,265 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf.c_str());
}
}
// server callbacks
//
common_speculative_callback::~common_speculative_callback() = default;
// server session
//
struct common_speculative_session::impl {
common_speculative_callback & callback;
common_params_speculative params_spec;
llama_context * ctx_tgt = nullptr;
common_speculative * spec = nullptr;
// `i_batch_dft`, idx of draft tokens in the main batch are stored in the caller
llama_tokens draft;
// use of checkpoints in speculative mode
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1)
size_t spec_ckpt_size_part = 0; // size of partial checkpoint
// Speculative decoding stats
int32_t n_draft_total = 0; // Total draft tokens generated
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
impl(common_speculative_callback & callback,
const common_params_speculative & params,
llama_context * ctx_tgt)
: callback(callback), params_spec(params), ctx_tgt(ctx_tgt) {
spec = common_speculative_init(params_spec, ctx_tgt);
}
void begin(const llama_tokens & prompt_history) {
common_speculative_begin(spec, prompt_history);
}
bool has_batch_dft() {
return !draft.empty();
}
void clear_draft() {
draft.clear();
spec_ckpt_n_denials = 0;
}
llama_tokens compute_draft(
const llama_tokens & cached_text_tokens,
llama_token id_last,
const int n_draft_max) {
if (spec == nullptr) {
// no implementation, nothing to do
clear_draft();
return draft;
}
if (n_draft_max == 0) {
clear_draft();
return draft;
}
if (params_spec.use_checkpoints && spec_ckpt_n_denials > 1) {
// We shouldn't get two denials.
LOG_WRN("%s: #tokens=%zu, spec_ckpt_n_denials=%d, id_last=%d, #draft=%zu\n", __func__,
cached_text_tokens.size(), spec_ckpt_n_denials, id_last, draft.size());
clear_draft();
return draft;
}
if (spec_ckpt_n_denials == 1) {
// there is a previous speculation which wasn't accepted in full length
if (draft.empty()) {
// switch to non-draft inference
LOG_DBG("%s: draft of length 0 after denied checkpoint\n", __func__);
return draft;
}
// we use the shortened draft of previous speculation
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
cached_text_tokens.size(), id_last, draft.size());
} else if (spec_ckpt_n_denials > 1) {
GGML_ABORT("illegal state: spec_ckpt_n_denials = %d > 1", spec_ckpt_n_denials);
} else {
// call the speculative implementation to create a draft
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, draft.size());
if (draft.empty()) {
clear_draft();
return draft;
}
}
if (draft.size() > (size_t) n_draft_max) {
LOG_WRN("draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
}
bool do_checkpoint = !draft.empty() && params_spec.use_checkpoints;
if (do_checkpoint && cached_text_tokens.size() > 5 && draft.size() >= 3) {
LOG_DBG("%s: #tokens=%zu, draft.size=%zu, n_spec_denials=%d, do_checkpoint=%s, id_last=%d, tokens=[..., %d, %d, %d], draft=[%d, %d, %d, ...]\n",
__func__,
cached_text_tokens.size(),
draft.size(), spec_ckpt_n_denials,
do_checkpoint ? "yes" : "no", id_last,
cached_text_tokens[cached_text_tokens.size() - 3],
cached_text_tokens[cached_text_tokens.size() - 2],
cached_text_tokens[cached_text_tokens.size() - 1],
draft[0], draft[1], draft[2]);
}
if (params_spec.n_min > (int) draft.size()) {
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min);
clear_draft();
return draft;
}
if (do_checkpoint) {
const size_t n = callback.create_checkpoint();
if (n == 0) {
LOG_WRN("%s: checkpoint creation failed (#tokens=%zu)\n", __func__, cached_text_tokens.size());
clear_draft();
return draft;
}
spec_ckpt_size_part = n;
spec_has_ckpt = true;
}
// add last sampled token to the batch
callback.batch_add_token(id_last, true);
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
callback.batch_add_token(draft[i], true);
}
return draft;
}
common_speculative_accept_response sample_and_accept() {
const size_t n_draft = draft.size();
// the accepted tokens from the speculation
auto ids = callback.sampler_sample_and_accept_n(draft);
LOG_DBG("%s: n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
if (ids.size() < n_draft + 1) {
// the main model rejected some tokens
// we shorten the draft
draft.resize(ids.size() - 1);
if (spec_has_ckpt) {
// we need to rollback to the state before sampling the draft tokens
// (restore_checkpoint shortens context and slot.prompt.tokens)
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
__func__,
ids.size() - 1, n_draft, n);
// delete Checkpoint
callback.delete_checkpoint();
spec_has_ckpt = false;
spec_ckpt_n_denials++;
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
// we will do the batch again but with the shortened draft
return common_speculative_accept_response(std::move(ids), n_draft, true);
}
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
draft.clear();
// use the sampled token only
ids.resize(1);
// drafted tokens in prompt have been deleted in restore_checkpoint(...).
// skip acceptance, don't calculate a new draft
return common_speculative_accept_response{std::move(ids), 0, true};
}
}
const size_t draft_size_accepted = draft.size();
LOG_DBG("%s: draft.size=%zu, ids.size=%zu\n", __func__, draft_size_accepted, ids.size());
common_speculative_accept(spec, draft_size_accepted);
draft.clear();
return common_speculative_accept_response{std::move(ids), n_draft, false};
}
void rewind(llama_pos p0) {
spec_ckpt_n_denials = 0;
if (spec_has_ckpt) {
// Delete Checkpoint
callback.delete_checkpoint();
spec_has_ckpt = false;
} else {
callback.memory_seq_rm(p0, -1);
}
}
void print_stats() const {
if (spec == nullptr) {
return;
}
common_speculative_print_stats(spec);
}
void reset() {
if (spec == nullptr) {
return;
}
clear_draft();
spec_has_ckpt = false;
spec_ckpt_size_part = 0;
}
};
common_speculative_session::common_speculative_session(
common_speculative_callback & callback,
const common_params_speculative & params,
llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
}
common_speculative_session::~common_speculative_session() {
common_speculative_free(p_impl->spec);
delete p_impl;
}
void common_speculative_session::begin(const llama_tokens & prompt_history) {
p_impl->begin(prompt_history);
}
bool common_speculative_session::has_batch_dft() {
return !p_impl->has_batch_dft();
}
llama_tokens common_speculative_session::compute_draft(
const llama_tokens & prompt,
llama_token id_last,
int n_draft_max_slot) {
return p_impl->compute_draft(prompt, id_last, n_draft_max_slot);
}
common_speculative_accept_response common_speculative_session::sample_and_accept() {
return p_impl->sample_and_accept();
}
void common_speculative_session::rewind(const llama_pos p0) {
p_impl->rewind(p0);
}
void common_speculative_session::print_stats() const {
p_impl->print_stats();
}
void common_speculative_session::reset() {
p_impl->reset();
}

View File

@ -3,6 +3,15 @@
#include "llama.h"
#include "common.h"
// common/speculative.h has two interfaces:
//
// 1) struct common_speculative with init, begin, draft, accept and print_stats
// Simple interface, see examples/speculative/speculative.cpp
//
// 2) struct common_speculative_session with struct common_speculative_callback
// Complex interface which supports checkpoints, see tools/server/server-context.cpp
//
struct common_speculative;
// comma separated list of all types
@ -14,9 +23,15 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
enum common_speculative_compat_type {
COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0,
COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1,
COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2,
};
// check if the llama_context is compatible for speculative decoding
// note: clears the memory of the context
bool common_speculative_is_compat(llama_context * ctx_tgt);
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt);
common_speculative * common_speculative_init(
common_params_speculative & params,
@ -39,3 +54,88 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
// print statistics about the speculative decoding
void common_speculative_print_stats(const common_speculative * spec);
// Interactions with server
//
// callback implemented by the server
struct common_speculative_callback {
virtual ~common_speculative_callback();
// Add a token to the draft sequence.
virtual void batch_add_token(llama_token token, bool logits) = 0;
// Sample and accept tokens from the main model.
virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0;
// Deletes a part of the context.
// Returns true if the memory was modified.
virtual bool memory_seq_rm(llama_pos p0, llama_pos p1) = 0;
// Creates a checkpoint of the current state of the context.
// Returns the size of the checkpoint in bytes.
virtual size_t create_checkpoint() = 0;
// Restore a checkpoint previously created by create_checkpoint().
// Returns the size of the restored checkpoint in bytes.
virtual size_t restore_checkpoint(size_t ckpt_size_part_expected) = 0;
// Delete a checkpoint previously created by create_checkpoint().
virtual void delete_checkpoint() = 0;
};
struct common_speculative_accept_response {
llama_tokens tokens;
size_t draft_size_initial;
bool skip_acceptance;
common_speculative_accept_response(llama_tokens t, size_t draft_size_initial, bool skip)
: tokens(std::move(t)), draft_size_initial(draft_size_initial), skip_acceptance(skip) {}
};
// speculative decoding which may use checkpoints to rewind in tokens history
struct common_speculative_session {
common_speculative_session(
common_speculative_callback & callback,
const common_params_speculative & params,
llama_context * ctx_tgt);
~common_speculative_session();
// don't copy
common_speculative_session(const common_speculative_session &) = delete;
common_speculative_session & operator=(const common_speculative_session &) = delete;
// call once at the beginning of a new generation
// some spec implementations use the prompt history to initialize lookup maps
void begin(const llama_tokens & prompt_history);
bool has_batch_dft();
// do speculative decoding to compute a draft of tokens
llama_tokens compute_draft(const llama_tokens & prompt,
llama_token id_last,
int n_draft_max_slot);
// check if and how far the current draft is accepted
common_speculative_accept_response sample_and_accept();
// rewind (because of a draft not fully accepted)
void rewind(llama_pos p0);
// print statistics
void print_stats() const;
// reset and delete structures
void reset();
private:
struct impl;
impl * p_impl;
};

View File

@ -1,3 +1,4 @@
#include "server-context.h"
#include "server-common.h"
#include "server-http.h"
@ -56,7 +57,9 @@ struct server_slot {
// multimodal
mtmd_context * mctx = nullptr;
common_speculative * spec = nullptr;
std::unique_ptr<common_speculative_callback> spec_callback;
std::unique_ptr<common_speculative_session> spec_session = nullptr;
struct common_sampler * spec_saved_sampler = nullptr;
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
@ -147,7 +150,6 @@ struct server_slot {
common_sampler_ptr smpl;
llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted;
// stats
size_t n_sent_text = 0; // number of sent text character
@ -177,7 +179,9 @@ struct server_slot {
stopping_word = "";
n_sent_text = 0;
drafted.clear();
if (spec_session != nullptr) {
spec_session->reset();
}
i_batch_dft.clear();
generated_tokens.clear();
generated_token_probs.clear();
@ -259,7 +263,7 @@ struct server_slot {
}
bool can_speculate() const {
return !!spec;
return !!spec_session;
}
void add_token(const completion_token_output & token) {
@ -399,7 +403,9 @@ struct server_slot {
);
}
common_speculative_print_stats(spec);
if (spec_session) {
spec_session->print_stats();
}
}
json to_json(bool only_metrics = false) const {
@ -598,8 +604,14 @@ private:
// Clear any sampling context
for (server_slot & slot : slots) {
common_speculative_free(slot.spec);
slot.spec = nullptr;
if (slot.spec_session != nullptr) {
slot.spec_session->reset();
slot.spec_session = nullptr;
}
if (slot.spec_saved_sampler != nullptr) {
common_sampler_free(slot.spec_saved_sampler);
slot.spec_saved_sampler = nullptr;
}
}
llama_batch_free(batch);
@ -630,6 +642,97 @@ private:
sleeping = new_state;
}
//
// callback for speculative decoding
//
struct server_speculative_callback : public common_speculative_callback {
int slot_id; // store slot.id instead of server_slot & slot
server_context_impl & ctx_impl;
server_speculative_callback(int slot_id, server_context_impl & ctx_impl)
: slot_id(slot_id), ctx_impl(ctx_impl) {}
server_slot * get_slot() {
server_slot * slot = ctx_impl.get_slot_by_id(slot_id);
if (slot == nullptr) {
GGML_ABORT("missing slot, slot.id=%d", slot_id);
}
return slot;
}
void batch_add_token(llama_token token, bool logits) override {
server_slot * slot = get_slot();
slot->i_batch_dft.push_back(ctx_impl.batch.n_tokens);
common_batch_add(ctx_impl.batch, token, slot->prompt.tokens.pos_next(), { slot_id }, logits);
slot->prompt.tokens.push_back(token);
}
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
const server_slot * slot = get_slot();
if (slot->i_batch_dft.size() != 1 + drafted.size()) {
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
__func__, slot->i_batch_dft.size(), 1 + drafted.size());
}
const auto ids = common_sampler_sample_and_accept_n(slot->smpl.get(), ctx_impl.ctx, slot->i_batch_dft, drafted);
return ids;
}
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot_id, p0, p1);
}
size_t create_checkpoint() override {
server_slot * slot = get_slot();
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot_id);
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot_id);
const auto n_tokens_cur = 0; // TODO was ctx_impl.batch.n_tokens; The draft model doesn't change the prompt?
const auto cur = ctx_impl.get_checkpoint(*slot, n_tokens_cur, pos_min, pos_max);
SLT_DBG(*slot, "created context checkpoint %zu of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
slot->prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints,
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
if (slot->spec_saved_sampler != nullptr) {
common_sampler_free(slot->spec_saved_sampler);
}
// save sampler (we may want to restore the RNG in the sampler after refusal of a draft)
slot->spec_saved_sampler = common_sampler_clone(slot->smpl.get());
return cur.size();
}
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
server_slot * slot = get_slot();
auto & ckpt = slot->prompt.checkpoints.back();
SLT_DBG(*slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != ckpt_size_part_expected) {
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
}
// remove entries after ckpt.pos_max
llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot->id, ckpt.pos_max + 1, -1);
slot->prompt.tokens.keep_first(ckpt.pos_max + 1);
if (slot->spec_saved_sampler != nullptr) {
slot->smpl.reset(slot->spec_saved_sampler);
slot->spec_saved_sampler = nullptr;
}
return n;
}
void delete_checkpoint() override {
server_slot * slot = get_slot();
slot->prompt.checkpoints.pop_back();
}
};
// load the model and initialize llama_context
// this may also be called to resume from sleeping state
bool load_model(const common_params & params) {
@ -656,6 +759,7 @@ private:
add_bos_token = llama_vocab_get_add_bos(vocab);
if (params_base.speculative.has_dft()) {
// TODO speculative: move to common/speculative.cpp?
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
const auto & params_spec = params_base.speculative;
@ -764,14 +868,23 @@ private:
slots.clear();
const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
const auto spec_type = common_speculative_is_compat(ctx);
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
SRV_WRN("%s", "speculative decoding will use checkpoints\n");
params_base.speculative.use_checkpoints = true;
}
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
// Create a new slot in the vector.
slots.emplace_back();
// Get a reference of the new slot.
server_slot & slot = slots.back();
slot.id = i;
slot.ctx = ctx;
@ -781,17 +894,15 @@ private:
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
} else {
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
if (mctx) {
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
return false;
}
slot.spec_callback = std::make_unique<server_speculative_callback>(slot.id, *this);
slot.spec_session = std::make_unique<common_speculative_session>(*slot.spec_callback,
params_base.speculative, slot.ctx);
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
}
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
@ -801,8 +912,6 @@ private:
};
slot.reset();
slots.push_back(std::move(slot));
}
{
@ -1192,7 +1301,7 @@ private:
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
backend_sampling &= !(slot.spec_session && task.params.speculative.n_max > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
@ -1698,6 +1807,38 @@ private:
return true;
}
// Creates a checkpoint.
//
// n_tokens_cur: the number of tokens added to the batch for the current slot
server_prompt_checkpoint get_checkpoint(server_slot & slot, const int64_t n_tokens_cur,
llama_pos pos_min, llama_pos pos_max) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
if (n != checkpoint_size) {
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
}
return cur;
}
void process_single_task(server_task && task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
@ -2098,56 +2239,30 @@ private:
// generate draft tokens in speculative decoding mode
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
// perform the speculative drafting for all sequences at the same time in a single batch
const int n_draft_max = slot.get_n_draft_max();
if (n_draft_max > 0) {
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
llama_tokens draft;
const int n_draft_max_slot = slot.get_n_draft_max();
if (n_draft_max_slot > 0) {
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
const auto & params_spec = slot.task->params.speculative;
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
if (draft.size() > (size_t) n_draft_max) {
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
draft.resize(n_draft_max);
// compute draft and add draft to internal batch
draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
if (draft.size() > 0) {
SLT_DBG(slot, "compute_draft: id=%d, #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n",
slot.sampled,
cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size());
}
}
// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
if (slot.task->params.speculative.n_min > (int) draft.size()) {
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
// fallback to normal decoding
slot.i_batch = slot.i_batch_dft[0];
slot.drafted.clear();
slot.i_batch_dft.clear();
} else {
// keep track of total number of drafted tokens tested
slot.n_draft_total += draft.size();
// add all drafted tokens to the batch
for (size_t i = 0; i < draft.size(); i++) {
slot.i_batch_dft.push_back(batch.n_tokens);
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(draft[i]);
}
slot.drafted = std::move(draft);
}
} else {
if (draft.empty()) {
// no speculative decoding
slot.i_batch = batch.n_tokens;
slot.i_batch_dft.clear();
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
slot.prompt.tokens.push_back(slot.sampled);
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
SLT_DBG(slot, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
slot.sampled,
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
}
}
@ -2646,35 +2761,12 @@ private:
// no need to create checkpoints that are too close together
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max);
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
// yet processed and therefore it is not part of the checkpoint.
if (do_checkpoint) {
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
// make room for the new checkpoint, if needed
const auto & cur = slot.prompt.checkpoints.front();
SLT_WRN(slot,
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
}
const size_t checkpoint_size =
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
/*.pos_min = */ pos_min,
/*.pos_max = */ pos_max,
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
/*.data = */ std::vector<uint8_t>(checkpoint_size),
});
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
const auto cur = get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
SLT_WRN(slot,
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
", size = %.3f MiB)\n",
@ -2851,7 +2943,7 @@ private:
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
slot.spec_session->begin(slot.prompt.tokens.get_text_tokens());
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
@ -2908,24 +3000,24 @@ private:
continue;
}
const size_t n_draft = slot.drafted.size();
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
auto accept_response = slot.spec_session->sample_and_accept();
slot.i_batch_dft.clear();
slot.drafted.clear();
const size_t n_draft = accept_response.draft_size_initial;
if (accept_response.skip_acceptance) {
SLT_DBG(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
continue;
}
const auto ids = accept_response.tokens;
const int64_t t_current = ggml_time_us();
slot.n_decoded += ids.size();
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
// update how many tokens out of those tested were accepted
slot.n_draft_accepted += ids.size() - 1;
// inform the speculative decoding about the number of accepted tokens
common_speculative_accept(slot.spec, ids.size() - 1);
slot.n_draft_total += n_draft;
// rollback to the state before sampling the draft tokens
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
@ -2933,8 +3025,9 @@ private:
// add accepted tokens to the prompt
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
slot.sampled = ids.back(); // last accepted token
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
slot.spec_session->rewind(slot.prompt.n_tokens());
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;