server : refactored spec logic to speculative.cpp
This commit is contained in:
parent
01763e800d
commit
e994c4ec1f
|
|
@ -1072,3 +1072,252 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||||
str_perf.c_str());
|
str_perf.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// server callbacks
|
||||||
|
//
|
||||||
|
|
||||||
|
common_speculative_callback::~common_speculative_callback() = default;
|
||||||
|
|
||||||
|
// server session
|
||||||
|
//
|
||||||
|
struct common_speculative_session::impl {
|
||||||
|
common_speculative_callback & callback;
|
||||||
|
common_params_speculative params_spec;
|
||||||
|
|
||||||
|
llama_context * ctx_tgt = nullptr;
|
||||||
|
|
||||||
|
common_speculative * spec = nullptr;
|
||||||
|
|
||||||
|
// `i_batch_dft`, idx of draft tokens in the main batch are stored in the caller
|
||||||
|
|
||||||
|
llama_tokens draft;
|
||||||
|
|
||||||
|
// use of checkpoints in speculative mode
|
||||||
|
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
|
||||||
|
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position (0 or 1)
|
||||||
|
size_t spec_ckpt_size_part = 0; // size of partial checkpoint
|
||||||
|
|
||||||
|
// Speculative decoding stats
|
||||||
|
int32_t n_draft_total = 0; // Total draft tokens generated
|
||||||
|
int32_t n_draft_accepted = 0; // Draft tokens actually accepted
|
||||||
|
|
||||||
|
impl(common_speculative_callback & callback,
|
||||||
|
const common_params_speculative & params,
|
||||||
|
llama_context * ctx_tgt)
|
||||||
|
: callback(callback), params_spec(params), ctx_tgt(ctx_tgt) {
|
||||||
|
spec = common_speculative_init(params_spec, ctx_tgt);
|
||||||
|
}
|
||||||
|
|
||||||
|
void begin(const llama_tokens & prompt_history) {
|
||||||
|
common_speculative_begin(spec, prompt_history);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_batch_dft() {
|
||||||
|
return !draft.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear_draft() {
|
||||||
|
draft.clear();
|
||||||
|
spec_ckpt_n_denials = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_tokens compute_draft(
|
||||||
|
const llama_tokens & cached_text_tokens,
|
||||||
|
llama_token id_last,
|
||||||
|
const int n_draft_max) {
|
||||||
|
if (spec == nullptr) {
|
||||||
|
// no implementation, nothing to do
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_draft_max == 0) {
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
if (params_spec.ckpt_num_tries > 0
|
||||||
|
&& spec_ckpt_n_denials >= params_spec.ckpt_num_tries) {
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (spec_ckpt_n_denials > 0) {
|
||||||
|
// there is a previous speculation which wasn't accepted in full length
|
||||||
|
if (draft.empty()) {
|
||||||
|
LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__);
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
// we use the shortened draft of previous speculation
|
||||||
|
LOG_INF("%s: resuse shortened draft, size=%zu\n", __func__, draft.size());
|
||||||
|
} else {
|
||||||
|
// call the speculative implementation to create a draft
|
||||||
|
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
|
||||||
|
LOG_DBG("draft: id_last=%d, #draft=%zu\n", id_last, draft.size());
|
||||||
|
if (draft.empty()) {
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (draft.size() > (size_t) n_draft_max) {
|
||||||
|
LOG_WRN("draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
|
||||||
|
draft.resize(n_draft_max);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool do_checkpoint = !draft.empty() && params_spec.ckpt_num_tries > 0;
|
||||||
|
if (do_checkpoint && cached_text_tokens.size() > 5) {
|
||||||
|
LOG_DBG("draft.size = %zu, n_spec_denials = %d, do_checkpoint = %s, tokens=[..., %d, %d, %d]\n",
|
||||||
|
draft.size(), spec_ckpt_n_denials,
|
||||||
|
do_checkpoint ? "yes" : "no",
|
||||||
|
cached_text_tokens[cached_text_tokens.size() - 3],
|
||||||
|
cached_text_tokens[cached_text_tokens.size() - 2],
|
||||||
|
cached_text_tokens[cached_text_tokens.size() - 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (do_checkpoint) {
|
||||||
|
const size_t n = callback.create_checkpoint();
|
||||||
|
if (n == 0) {
|
||||||
|
LOG_WRN("checkpoint creation failed");
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
spec_ckpt_size_part = n;
|
||||||
|
spec_has_ckpt = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params_spec.n_min > (int) draft.size()) {
|
||||||
|
LOG_DBG("ignoring small draft: %d < %d\n", (int) draft.size(), params_spec.n_min);
|
||||||
|
clear_draft();
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
|
||||||
|
// add last sampled token to the batch
|
||||||
|
callback.batch_add_token(id_last, true);
|
||||||
|
|
||||||
|
// add all drafted tokens to the batch
|
||||||
|
for (size_t i = 0; i < draft.size(); i++) {
|
||||||
|
callback.batch_add_token(draft[i], true);
|
||||||
|
}
|
||||||
|
|
||||||
|
return draft;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_accept_response sample_and_accept() {
|
||||||
|
const size_t n_draft = draft.size();
|
||||||
|
|
||||||
|
// the accepted tokens from the speculation
|
||||||
|
auto ids = callback.sampler_sample_and_accept_n(draft);
|
||||||
|
|
||||||
|
LOG_DBG("%s: n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
|
||||||
|
if (ids.size() < n_draft + 1) {
|
||||||
|
// the main model rejected some tokens
|
||||||
|
|
||||||
|
// we shorten the draft
|
||||||
|
draft.resize(ids.size() - 1);
|
||||||
|
if (spec_has_ckpt) {
|
||||||
|
// we need to rollback to the state before sampling the draft tokens
|
||||||
|
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
|
||||||
|
LOG_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
|
||||||
|
ids.size() -1 , n_draft, n);
|
||||||
|
|
||||||
|
// rollback to the state before sampling the draft tokens
|
||||||
|
|
||||||
|
// Delete Checkpoint
|
||||||
|
callback.delete_checkpoint();
|
||||||
|
spec_has_ckpt = false;
|
||||||
|
|
||||||
|
if (n_draft > 0 && spec_ckpt_n_denials == 0) {
|
||||||
|
// we will do the batch again but with the shortened draft
|
||||||
|
spec_ckpt_n_denials++;
|
||||||
|
|
||||||
|
return common_speculative_accept_response(std::move(ids), n_draft, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
//spec_ckpt_n_accepted = (spec_ckpt_n_denials < params_spec.ckpt_num_tries) ? (int) (ids.size() - 1) : 0;
|
||||||
|
|
||||||
|
callback.batch_clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const size_t draft_size_accepted = draft.size();
|
||||||
|
LOG_DBG("%s: draft.size=%zu\n", __func__, draft_size_accepted);
|
||||||
|
common_speculative_accept(spec, draft_size_accepted);
|
||||||
|
draft.clear();
|
||||||
|
|
||||||
|
return common_speculative_accept_response{std::move(ids), n_draft, false};
|
||||||
|
}
|
||||||
|
|
||||||
|
void rewind(const llama_pos p0) {
|
||||||
|
spec_ckpt_n_denials = 0;
|
||||||
|
if (spec_has_ckpt) {
|
||||||
|
// Delete Checkpoint
|
||||||
|
callback.delete_checkpoint();
|
||||||
|
spec_has_ckpt = false;
|
||||||
|
} else {
|
||||||
|
callback.memory_seq_rm(p0, -1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_stats() const {
|
||||||
|
if (spec == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_print_stats(spec);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
if (spec == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
clear_draft();
|
||||||
|
|
||||||
|
spec_has_ckpt = false;
|
||||||
|
spec_ckpt_size_part = 0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
common_speculative_session::common_speculative_session(
|
||||||
|
common_speculative_callback & callback,
|
||||||
|
const common_params_speculative & params,
|
||||||
|
llama_context * ctx_tgt) : p_impl(new impl{callback, params, ctx_tgt}) {
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_session::~common_speculative_session() {
|
||||||
|
common_speculative_free(p_impl->spec);
|
||||||
|
delete p_impl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_session::begin(const llama_tokens & prompt_history) {
|
||||||
|
p_impl->begin(prompt_history);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool common_speculative_session::has_batch_dft() {
|
||||||
|
return !p_impl->has_batch_dft();
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_tokens common_speculative_session::compute_draft(
|
||||||
|
const llama_tokens & prompt,
|
||||||
|
llama_token id_last,
|
||||||
|
int n_draft_max_slot) {
|
||||||
|
return p_impl->compute_draft(prompt, id_last, n_draft_max_slot);
|
||||||
|
}
|
||||||
|
|
||||||
|
common_speculative_accept_response common_speculative_session::sample_and_accept() {
|
||||||
|
return p_impl->sample_and_accept();
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_session::rewind(const llama_pos p0) {
|
||||||
|
p_impl->rewind(p0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_session::print_stats() const {
|
||||||
|
p_impl->print_stats();
|
||||||
|
}
|
||||||
|
|
||||||
|
void common_speculative_session::reset() {
|
||||||
|
p_impl->reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,15 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
|
// common/speculative.h has two interfaces:
|
||||||
|
//
|
||||||
|
// 1) struct common_speculative with init, begin, draft, accept and print_stats
|
||||||
|
// Simple interface, see examples/speculative/speculative.cpp
|
||||||
|
//
|
||||||
|
// 2) struct common_speculative_session with struct common_speculative_callback
|
||||||
|
// Complex interface which supports checkpoints, see tools/server/server-context.cpp
|
||||||
|
//
|
||||||
|
|
||||||
struct common_speculative;
|
struct common_speculative;
|
||||||
|
|
||||||
// comma separated list of all types
|
// comma separated list of all types
|
||||||
|
|
@ -39,3 +48,91 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||||
|
|
||||||
// print statistics about the speculative decoding
|
// print statistics about the speculative decoding
|
||||||
void common_speculative_print_stats(const common_speculative * spec);
|
void common_speculative_print_stats(const common_speculative * spec);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// Interactions with server
|
||||||
|
//
|
||||||
|
|
||||||
|
// callback implemented by the server
|
||||||
|
struct common_speculative_callback {
|
||||||
|
virtual ~common_speculative_callback();
|
||||||
|
|
||||||
|
// Add a token to the draft sequence.
|
||||||
|
virtual void batch_add_token(const llama_token token, bool logits) = 0;
|
||||||
|
|
||||||
|
// Clears the batch context.
|
||||||
|
virtual void batch_clear() = 0;
|
||||||
|
|
||||||
|
// Sample and accept tokens from the main model.
|
||||||
|
virtual llama_tokens sampler_sample_and_accept_n(const llama_tokens & drafted) = 0;
|
||||||
|
|
||||||
|
// Deletes a part of the context.
|
||||||
|
// Returns true if the memory was modified.
|
||||||
|
virtual bool memory_seq_rm(llama_pos p0, llama_pos p1) = 0;
|
||||||
|
|
||||||
|
// Creates a checkpoint of the current state of the context.
|
||||||
|
// Returns the size of the checkpoint in bytes.
|
||||||
|
virtual size_t create_checkpoint() = 0;
|
||||||
|
|
||||||
|
// Restore a checkpoint previously created by create_checkpoint().
|
||||||
|
// Returns the size of the restored checkpoint in bytes.
|
||||||
|
virtual size_t restore_checkpoint(size_t ckpt_size_part_expected) = 0;
|
||||||
|
|
||||||
|
// Delete a checkpoint previously created by create_checkpoint().
|
||||||
|
virtual void delete_checkpoint() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct common_speculative_accept_response {
|
||||||
|
llama_tokens tokens;
|
||||||
|
size_t draft_size_initial;
|
||||||
|
bool skip_acceptance;
|
||||||
|
|
||||||
|
common_speculative_accept_response(llama_tokens t, size_t draft_size_initial, bool skip)
|
||||||
|
: tokens(std::move(t)), draft_size_initial(draft_size_initial), skip_acceptance(skip) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// speculative decoding which may use checkpoints to rewind in tokens history
|
||||||
|
struct common_speculative_session {
|
||||||
|
|
||||||
|
common_speculative_session(
|
||||||
|
common_speculative_callback & callback,
|
||||||
|
const common_params_speculative & params,
|
||||||
|
llama_context * ctx_tgt);
|
||||||
|
|
||||||
|
~common_speculative_session();
|
||||||
|
|
||||||
|
// dont copy
|
||||||
|
common_speculative_session(const common_speculative_session &) = delete;
|
||||||
|
common_speculative_session & operator=(const common_speculative_session &) = delete;
|
||||||
|
|
||||||
|
|
||||||
|
// call once at the beginning of a new generation
|
||||||
|
// some spec implementations use the prompt history to initialize lookup maps
|
||||||
|
void begin(const llama_tokens & prompt_history);
|
||||||
|
|
||||||
|
bool has_batch_dft();
|
||||||
|
|
||||||
|
// do speculative decoding to compute a draft of tokens
|
||||||
|
llama_tokens compute_draft(const llama_tokens & prompt,
|
||||||
|
llama_token id_last,
|
||||||
|
int n_draft_max_slot);
|
||||||
|
|
||||||
|
// check if and how far the current draft is accepted
|
||||||
|
common_speculative_accept_response sample_and_accept();
|
||||||
|
|
||||||
|
// rewind (because of a draft not fully accepted)
|
||||||
|
void rewind(const llama_pos p0);
|
||||||
|
|
||||||
|
// print statistics
|
||||||
|
void print_stats() const;
|
||||||
|
|
||||||
|
// reset and delete structures
|
||||||
|
void reset();
|
||||||
|
|
||||||
|
private:
|
||||||
|
struct impl;
|
||||||
|
impl * p_impl;
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -56,7 +56,8 @@ struct server_slot {
|
||||||
// multimodal
|
// multimodal
|
||||||
mtmd_context * mctx = nullptr;
|
mtmd_context * mctx = nullptr;
|
||||||
|
|
||||||
common_speculative * spec = nullptr;
|
std::unique_ptr<common_speculative_callback> spec_callback;
|
||||||
|
common_speculative_session * spec_session = nullptr;
|
||||||
|
|
||||||
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
// TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
|
||||||
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
// see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
|
||||||
|
|
@ -147,14 +148,6 @@ struct server_slot {
|
||||||
common_sampler_ptr smpl;
|
common_sampler_ptr smpl;
|
||||||
|
|
||||||
llama_token sampled; // in speculative mode, this is the last accepted token
|
llama_token sampled; // in speculative mode, this is the last accepted token
|
||||||
llama_tokens drafted;
|
|
||||||
|
|
||||||
// use of checkpoints in speculative mode
|
|
||||||
bool spec_has_ckpt = false; // true if a checkpoint for rollback after partial speculation has been created
|
|
||||||
uint16_t spec_ckpt_n_denials = 0; // number of drafts not accepted at the current position
|
|
||||||
int spec_ckpt_n_accepted = 0; // number of accepted tokens at current position
|
|
||||||
size_t spec_ckpt_size_part = 0; // size of partial checkpoint
|
|
||||||
|
|
||||||
|
|
||||||
// stats
|
// stats
|
||||||
size_t n_sent_text = 0; // number of sent text character
|
size_t n_sent_text = 0; // number of sent text character
|
||||||
|
|
@ -184,7 +177,9 @@ struct server_slot {
|
||||||
stopping_word = "";
|
stopping_word = "";
|
||||||
n_sent_text = 0;
|
n_sent_text = 0;
|
||||||
|
|
||||||
drafted.clear();
|
if (spec_session != nullptr) {
|
||||||
|
spec_session->reset();
|
||||||
|
}
|
||||||
i_batch_dft.clear();
|
i_batch_dft.clear();
|
||||||
generated_tokens.clear();
|
generated_tokens.clear();
|
||||||
generated_token_probs.clear();
|
generated_token_probs.clear();
|
||||||
|
|
@ -194,11 +189,6 @@ struct server_slot {
|
||||||
n_draft_total = 0;
|
n_draft_total = 0;
|
||||||
n_draft_accepted = 0;
|
n_draft_accepted = 0;
|
||||||
|
|
||||||
spec_ckpt_n_denials = 0;
|
|
||||||
spec_ckpt_n_accepted = 0;
|
|
||||||
spec_has_ckpt = false;
|
|
||||||
spec_ckpt_size_part = 0;
|
|
||||||
|
|
||||||
task_prev = std::move(task);
|
task_prev = std::move(task);
|
||||||
task.reset();
|
task.reset();
|
||||||
|
|
||||||
|
|
@ -271,7 +261,7 @@ struct server_slot {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool can_speculate() const {
|
bool can_speculate() const {
|
||||||
return !!spec;
|
return !!spec_session;
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_token(const completion_token_output & token) {
|
void add_token(const completion_token_output & token) {
|
||||||
|
|
@ -411,7 +401,9 @@ struct server_slot {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
common_speculative_print_stats(spec);
|
if (spec_session) {
|
||||||
|
spec_session->print_stats();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json(bool only_metrics = false) const {
|
json to_json(bool only_metrics = false) const {
|
||||||
|
|
@ -610,8 +602,10 @@ private:
|
||||||
|
|
||||||
// Clear any sampling context
|
// Clear any sampling context
|
||||||
for (server_slot & slot : slots) {
|
for (server_slot & slot : slots) {
|
||||||
common_speculative_free(slot.spec);
|
if (slot.spec_session != nullptr) {
|
||||||
slot.spec = nullptr;
|
slot.spec_session->reset();
|
||||||
|
slot.spec_session = nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
|
@ -631,6 +625,74 @@ private:
|
||||||
sleeping = new_state;
|
sleeping = new_state;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// callback for speculative decoding
|
||||||
|
//
|
||||||
|
struct server_speculative_callback : public common_speculative_callback {
|
||||||
|
server_slot & slot;
|
||||||
|
server_context_impl & ctx_impl;
|
||||||
|
|
||||||
|
server_speculative_callback(server_slot & slot, server_context_impl & ctx_impl)
|
||||||
|
: slot(slot), ctx_impl(ctx_impl) {}
|
||||||
|
|
||||||
|
void batch_add_token(const llama_token token, bool logits) override {
|
||||||
|
slot.i_batch_dft.push_back(ctx_impl.batch.n_tokens);
|
||||||
|
common_batch_add(ctx_impl.batch, token, slot.prompt.tokens.pos_next(), { slot.id }, logits);
|
||||||
|
slot.prompt.tokens.push_back(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_clear() override {
|
||||||
|
common_batch_clear(ctx_impl.batch);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token> sampler_sample_and_accept_n(const llama_tokens & drafted) override {
|
||||||
|
if (slot.i_batch_dft.size() != 1 + drafted.size()) {
|
||||||
|
GGML_ABORT("%s: #i_batch_dft = %zu != 1 + #drafted=%zu",
|
||||||
|
__func__, slot.i_batch_dft.size(), 1 + drafted.size());
|
||||||
|
}
|
||||||
|
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx_impl.ctx, slot.i_batch_dft, drafted);
|
||||||
|
|
||||||
|
return ids;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool memory_seq_rm(llama_pos p0, llama_pos p1) override {
|
||||||
|
return llama_memory_seq_rm(llama_get_memory(ctx_impl.ctx), slot.id, p0, p1);
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t create_checkpoint() override {
|
||||||
|
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_impl.ctx), slot.id);
|
||||||
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_impl.ctx), slot.id);
|
||||||
|
const auto n_tokens_cur = batch.n_tokens;
|
||||||
|
const auto & cur_with_size = ctx_impl.get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
||||||
|
auto & cur = cur_with_size.checkpoint;
|
||||||
|
|
||||||
|
SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||||
|
(int) slot.prompt.checkpoints.size(), ctx_impl.params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||||
|
return cur_with_size.size;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t restore_checkpoint(size_t ckpt_size_part_expected) override {
|
||||||
|
auto & ckpt = slot.prompt.checkpoints.back();
|
||||||
|
|
||||||
|
SLT_INF(slot, "restoring checkpoint (pos_min = %d, pos_max = %d)\n", ckpt.pos_min, ckpt.pos_max);
|
||||||
|
const size_t n = llama_state_seq_set_data_ext(ctx_impl.ctx,
|
||||||
|
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != ckpt_size_part_expected) {
|
||||||
|
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||||
|
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
|
void delete_checkpoint() override {
|
||||||
|
slot.prompt.checkpoints.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
// load the model and initialize llama_context
|
// load the model and initialize llama_context
|
||||||
// this may also be called to resume from sleeping state
|
// this may also be called to resume from sleeping state
|
||||||
bool load_model(const common_params & params) {
|
bool load_model(const common_params & params) {
|
||||||
|
|
@ -657,6 +719,7 @@ private:
|
||||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||||
|
|
||||||
if (params_base.speculative.has_dft()) {
|
if (params_base.speculative.has_dft()) {
|
||||||
|
// TODO speculative: move to common/speculative.cpp?
|
||||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str());
|
||||||
|
|
||||||
const auto & params_spec = params_base.speculative;
|
const auto & params_spec = params_base.speculative;
|
||||||
|
|
@ -772,7 +835,11 @@ private:
|
||||||
|
|
||||||
// initialize slots
|
// initialize slots
|
||||||
for (int i = 0; i < params_base.n_parallel; i++) {
|
for (int i = 0; i < params_base.n_parallel; i++) {
|
||||||
server_slot slot;
|
// Create a new slot in the vector.
|
||||||
|
slots.emplace_back();
|
||||||
|
|
||||||
|
// Get a reference of the new slot.
|
||||||
|
server_slot & slot = slots.back();
|
||||||
|
|
||||||
slot.id = i;
|
slot.id = i;
|
||||||
slot.ctx = ctx;
|
slot.ctx = ctx;
|
||||||
|
|
@ -783,16 +850,14 @@ private:
|
||||||
|
|
||||||
// try speculative decoding
|
// try speculative decoding
|
||||||
if (can_spec || params_base.speculative.ckpt_num_tries > 0) {
|
if (can_spec || params_base.speculative.ckpt_num_tries > 0) {
|
||||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
if (mctx) {
|
||||||
if (slot.spec) {
|
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
||||||
if (mctx) {
|
return false;
|
||||||
SRV_ERR("%s\n", "speculative decoding is not supported with multimodal");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
|
||||||
} else {
|
|
||||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
|
||||||
}
|
}
|
||||||
|
slot.spec_callback = std::make_unique<server_speculative_callback>(slot, *this);
|
||||||
|
slot.spec_session = new common_speculative_session(*slot.spec_callback,
|
||||||
|
params_base.speculative, slot.ctx);
|
||||||
|
SLT_INF(slot, "%s", "speculative decoding context initialized\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||||
|
|
@ -802,8 +867,6 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
slot.reset();
|
slot.reset();
|
||||||
|
|
||||||
slots.push_back(std::move(slot));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
@ -1180,7 +1243,7 @@ private:
|
||||||
backend_sampling &= task.params.sampling.backend_sampling;
|
backend_sampling &= task.params.sampling.backend_sampling;
|
||||||
|
|
||||||
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
// TODO: speculative decoding requires multiple samples per batch - not supported yet
|
||||||
backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0);
|
backend_sampling &= !(slot.spec_session && task.params.speculative.n_max > 0);
|
||||||
|
|
||||||
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
|
||||||
backend_sampling &= !need_logits;
|
backend_sampling &= !need_logits;
|
||||||
|
|
@ -1686,6 +1749,43 @@ private:
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct server_prompt_checkpoint_with_size {
|
||||||
|
server_prompt_checkpoint checkpoint;
|
||||||
|
size_t size;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Creates a checkpoint.
|
||||||
|
//
|
||||||
|
// n_tokens_cur: the number of tokens added to the batch for the current slot
|
||||||
|
server_prompt_checkpoint_with_size get_checkpoint(server_slot & slot, const int64_t n_tokens_cur,
|
||||||
|
llama_pos pos_min, llama_pos pos_max) {
|
||||||
|
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||||
|
// make room for the new checkpoint, if needed
|
||||||
|
const auto & cur = slot.prompt.checkpoints.front();
|
||||||
|
|
||||||
|
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||||
|
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||||
|
|
||||||
|
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
|
||||||
|
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||||
|
/*.pos_min = */ pos_min,
|
||||||
|
/*.pos_max = */ pos_max,
|
||||||
|
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
|
||||||
|
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||||
|
});
|
||||||
|
|
||||||
|
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||||
|
if (n != checkpoint_size) {
|
||||||
|
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
return server_prompt_checkpoint_with_size{ cur, checkpoint_size };
|
||||||
|
}
|
||||||
|
|
||||||
void process_single_task(server_task && task) {
|
void process_single_task(server_task && task) {
|
||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_COMPLETION:
|
||||||
|
|
@ -2080,107 +2180,18 @@ private:
|
||||||
// generate draft tokens in speculative decoding mode
|
// generate draft tokens in speculative decoding mode
|
||||||
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
// TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK]
|
||||||
// perform the speculative drafting for all sequences at the same time in a single batch
|
// perform the speculative drafting for all sequences at the same time in a single batch
|
||||||
const int n_draft_max = (slot.spec_ckpt_n_accepted > 0) ? slot.spec_ckpt_n_accepted : slot.get_n_draft_max();
|
const int n_draft_max_slot = slot.get_n_draft_max();
|
||||||
if (n_draft_max > 0 && (params_base.speculative.ckpt_num_tries == 0
|
|
||||||
|| slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries)) {
|
|
||||||
if (mctx) {
|
|
||||||
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
|
|
||||||
GGML_ABORT("not supported by multimodal");
|
|
||||||
}
|
|
||||||
|
|
||||||
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
|
||||||
|
llama_tokens draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
|
||||||
|
if (draft.size() > 0) {
|
||||||
|
SLT_DBG(slot, "compute_draft: #tokens=%d\n", (int) draft.size());
|
||||||
|
}
|
||||||
|
|
||||||
const auto & params_spec = slot.task->params.speculative;
|
if (draft.empty()) {
|
||||||
|
|
||||||
llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
|
|
||||||
|
|
||||||
if (draft.size() > (size_t) n_draft_max) {
|
|
||||||
SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max);
|
|
||||||
draft.resize(n_draft_max);
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
|
||||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
|
||||||
bool do_checkpoint = !draft.empty() && params_base.speculative.ckpt_num_tries > 0
|
|
||||||
&& slot.prompt.checkpoints.size() < (size_t) params_base.n_ctx_checkpoints;
|
|
||||||
if (do_checkpoint && cached_text_tokens.size() > 5) {
|
|
||||||
SLT_DBG(slot, "draft.size = %zu, n_spec_denials = %d, #ckpts=%zu, do_checkpoint = %s, pos_min = %d, pos_max = %d, tokens=[..., %d, %d, %d]\n",
|
|
||||||
draft.size(), slot.spec_ckpt_n_denials,
|
|
||||||
slot.prompt.checkpoints.size(),
|
|
||||||
do_checkpoint ? "yes" : "no", pos_min, pos_max,
|
|
||||||
cached_text_tokens[cached_text_tokens.size() - 3],
|
|
||||||
cached_text_tokens[cached_text_tokens.size() - 2],
|
|
||||||
cached_text_tokens[cached_text_tokens.size() - 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (do_checkpoint) {
|
|
||||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
|
||||||
// make room for the new checkpoint, if needed
|
|
||||||
const auto & cur = slot.prompt.checkpoints.front();
|
|
||||||
|
|
||||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
|
||||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
|
||||||
|
|
||||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, 0);
|
|
||||||
|
|
||||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
|
||||||
/*.pos_min = */ pos_min,
|
|
||||||
/*.pos_max = */ pos_max,
|
|
||||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
|
||||||
});
|
|
||||||
|
|
||||||
const size_t n = llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
|
|
||||||
SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
|
||||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
|
||||||
|
|
||||||
slot.spec_ckpt_size_part = n;
|
|
||||||
slot.spec_has_ckpt = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// add the sampled token to the batch
|
|
||||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
||||||
SLT_DBG(slot, "before common_batch_add: sampled=%d, pos_next=%d, tokens.size=%zu, tokens.last=%d\n",
|
|
||||||
slot.sampled, slot.prompt.tokens.pos_next(), slot.prompt.tokens.size(), slot.prompt.tokens[slot.prompt.tokens.size() -1]);
|
|
||||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
||||||
slot.prompt.tokens.push_back(slot.sampled);
|
|
||||||
|
|
||||||
if (slot.task->params.speculative.n_min > (int) draft.size()) {
|
|
||||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min);
|
|
||||||
// fallback to normal decoding
|
|
||||||
slot.i_batch = slot.i_batch_dft[0];
|
|
||||||
slot.drafted.clear();
|
|
||||||
slot.i_batch_dft.clear();
|
|
||||||
|
|
||||||
if (slot.spec_has_ckpt) {
|
|
||||||
slot.spec_ckpt_n_accepted = 0;
|
|
||||||
slot.spec_ckpt_n_denials = 0;
|
|
||||||
|
|
||||||
// Delete Checkpoint
|
|
||||||
slot.prompt.checkpoints.pop_back();
|
|
||||||
slot.spec_has_ckpt = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// keep track of total number of drafted tokens tested
|
|
||||||
slot.n_draft_total += draft.size();
|
|
||||||
|
|
||||||
// add all drafted tokens to the batch
|
|
||||||
for (size_t i = 0; i < draft.size(); i++) {
|
|
||||||
slot.i_batch_dft.push_back(batch.n_tokens);
|
|
||||||
common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true);
|
|
||||||
slot.prompt.tokens.push_back(draft[i]);
|
|
||||||
}
|
|
||||||
slot.drafted = std::move(draft);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// no speculative decoding
|
// no speculative decoding
|
||||||
slot.i_batch = batch.n_tokens;
|
slot.i_batch = batch.n_tokens;
|
||||||
|
slot.i_batch_dft.clear();
|
||||||
slot.spec_ckpt_n_denials = 0;
|
|
||||||
slot.spec_ckpt_n_accepted = 0;
|
|
||||||
|
|
||||||
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true);
|
||||||
|
|
||||||
|
|
@ -2690,31 +2701,8 @@ private:
|
||||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||||
// yet processed and therefore it is not part of the checkpoint.
|
// yet processed and therefore it is not part of the checkpoint.
|
||||||
if (do_checkpoint) {
|
if (do_checkpoint) {
|
||||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
auto cur_with_size = get_checkpoint(slot, n_tokens_cur, pos_min, pos_max);
|
||||||
// make room for the new checkpoint, if needed
|
auto & cur = cur_with_size.checkpoint;
|
||||||
const auto & cur = slot.prompt.checkpoints.front();
|
|
||||||
|
|
||||||
SLT_WRN(slot,
|
|
||||||
"erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
|
||||||
", size = %.3f MiB)\n",
|
|
||||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
|
||||||
|
|
||||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t checkpoint_size =
|
|
||||||
llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
|
|
||||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
|
||||||
/*.pos_min = */ pos_min,
|
|
||||||
/*.pos_max = */ pos_max,
|
|
||||||
/*.n_tokens = */ slot.prompt.n_tokens() - n_tokens_cur,
|
|
||||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
|
||||||
});
|
|
||||||
|
|
||||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id,
|
|
||||||
LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
|
|
||||||
SLT_WRN(slot,
|
SLT_WRN(slot,
|
||||||
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64
|
||||||
", size = %.3f MiB)\n",
|
", size = %.3f MiB)\n",
|
||||||
|
|
@ -2891,7 +2879,7 @@ private:
|
||||||
slot.state = SLOT_STATE_GENERATING;
|
slot.state = SLOT_STATE_GENERATING;
|
||||||
|
|
||||||
if (slot.can_speculate()) {
|
if (slot.can_speculate()) {
|
||||||
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
|
slot.spec_session->begin(slot.prompt.tokens.get_text_tokens());
|
||||||
}
|
}
|
||||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
continue; // continue loop of slots
|
continue; // continue loop of slots
|
||||||
|
|
@ -2948,61 +2936,23 @@ private:
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t n_draft = slot.drafted.size();
|
auto accept_response = slot.spec_session->sample_and_accept();
|
||||||
|
|
||||||
// the accepted tokens from the speculation
|
|
||||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
|
|
||||||
slot.i_batch_dft.clear();
|
slot.i_batch_dft.clear();
|
||||||
slot.drafted.clear();
|
const size_t n_draft = accept_response.draft_size_initial;
|
||||||
|
if (accept_response.skip_acceptance) {
|
||||||
const int64_t t_current = ggml_time_us();
|
SLT_INF(slot, "partial acceptance: n_tokens=%zu, n_draft=%zu\n", accept_response.tokens.size(), n_draft);
|
||||||
|
|
||||||
|
|
||||||
if (slot.spec_has_ckpt && ids.size() < n_draft + 1) {
|
|
||||||
// the main model rejected some tokens, so we need to rollback to the state before sampling the draft tokens
|
|
||||||
auto & ckpt = slot.prompt.checkpoints.back();
|
|
||||||
SLT_INF(slot, "partial acceptance: %zu < %zu, restoring checkpoint (pos_min = %d, pos_max = %d)\n",
|
|
||||||
ids.size() - 1, n_draft,
|
|
||||||
ckpt.pos_min, ckpt.pos_max);
|
|
||||||
const size_t n = llama_state_seq_set_data_ext(ctx,
|
|
||||||
ckpt.data.data(), ckpt.size(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
|
||||||
if (n != slot.spec_ckpt_size_part) {
|
|
||||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
|
||||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), slot.spec_ckpt_size_part, n);
|
|
||||||
}
|
|
||||||
SRV_INF("partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
|
|
||||||
ids.size() -1 , n_draft, n);
|
|
||||||
|
|
||||||
// rollback to the state before sampling the draft tokens
|
|
||||||
SLT_INF(slot, "partial acceptance: n_tokens=%d, n_draft=%zu, pos_max=%d\n",
|
|
||||||
slot.prompt.n_tokens(), n_draft, ckpt.pos_max);
|
|
||||||
|
|
||||||
slot.prompt.tokens.keep_first(ckpt.pos_max + 1);
|
|
||||||
|
|
||||||
// Delete Checkpoint
|
|
||||||
slot.prompt.checkpoints.pop_back();
|
|
||||||
slot.spec_has_ckpt = false;
|
|
||||||
|
|
||||||
// Inform the speculative implementation of the number of valid tokens.
|
|
||||||
// common_speculative_accept(slot.spec, ids.size() - 1);
|
|
||||||
|
|
||||||
slot.spec_ckpt_n_denials++;
|
|
||||||
slot.spec_ckpt_n_accepted = (slot.spec_ckpt_n_denials < params_base.speculative.ckpt_num_tries) ? (int) (ids.size() - 1) : 0;
|
|
||||||
|
|
||||||
common_batch_clear(batch);
|
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
const auto ids = accept_response.tokens;
|
||||||
|
|
||||||
|
|
||||||
|
const int64_t t_current = ggml_time_us();
|
||||||
|
|
||||||
slot.n_decoded += ids.size();
|
slot.n_decoded += ids.size();
|
||||||
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;
|
||||||
|
|
||||||
// update how many tokens out of those tested were accepted
|
// update how many tokens out of those tested were accepted
|
||||||
slot.n_draft_accepted += ids.size() - 1;
|
slot.n_draft_accepted += ids.size() - 1;
|
||||||
slot.spec_ckpt_n_accepted = 0;
|
|
||||||
|
|
||||||
// inform the speculative decoding about the number of accepted tokens
|
|
||||||
common_speculative_accept(slot.spec, ids.size() - 1);
|
|
||||||
|
|
||||||
// rollback to the state before sampling the draft tokens
|
// rollback to the state before sampling the draft tokens
|
||||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||||
|
|
@ -3011,17 +2961,7 @@ private:
|
||||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||||
slot.sampled = ids.back(); // last accepted token
|
slot.sampled = ids.back(); // last accepted token
|
||||||
|
|
||||||
slot.spec_ckpt_n_denials = 0;
|
slot.spec_session->rewind(slot.prompt.n_tokens());
|
||||||
if (slot.spec_has_ckpt) {
|
|
||||||
// Delete Checkpoint
|
|
||||||
if (slot.prompt.checkpoints.empty()) {
|
|
||||||
GGML_ABORT("missing checkpoint to delete");
|
|
||||||
}
|
|
||||||
slot.prompt.checkpoints.pop_back();
|
|
||||||
slot.spec_has_ckpt = false;
|
|
||||||
} else {
|
|
||||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 0; i < ids.size(); ++i) {
|
for (size_t i = 0; i < ids.size(); ++i) {
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue