speculative : checkpoints with draft model, logging
This commit is contained in:
parent
af3b630e0b
commit
fe4f859a67
|
|
@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
|||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -144,10 +144,28 @@ struct common_speculative_state {
|
|||
virtual void accept(uint16_t n_accepted) = 0;
|
||||
};
|
||||
|
||||
struct common_speculative_checkpoint {
|
||||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
size_t ckpt_size;
|
||||
};
|
||||
|
||||
struct common_speculative_state_draft : public common_speculative_state {
|
||||
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
|
||||
llama_context * ctx_dft;
|
||||
|
||||
struct common_speculative_checkpoint ckpt;
|
||||
bool use_checkpoint;
|
||||
|
||||
common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
|
|
@ -160,10 +178,12 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
enum common_speculative_type type,
|
||||
llama_context * ctx_tgt,
|
||||
llama_context * ctx_dft,
|
||||
const std::vector<std::pair<std::string, std::string>> & replacements)
|
||||
const std::vector<std::pair<std::string, std::string>> & replacements,
|
||||
bool use_checkpoint)
|
||||
: common_speculative_state(type)
|
||||
, ctx_tgt(ctx_tgt)
|
||||
, ctx_dft(ctx_dft)
|
||||
, use_checkpoint(use_checkpoint)
|
||||
{
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
smpl = nullptr;
|
||||
|
|
@ -218,7 +238,48 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
}
|
||||
|
||||
void begin(const llama_tokens & prompt) override {
|
||||
GGML_UNUSED(prompt);
|
||||
if (use_checkpoint && ckpt.size() > 0) {
|
||||
// delete checkpoint
|
||||
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%zu, size=%.3f MiB\n",
|
||||
__func__, prompt.size(),
|
||||
ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
|
||||
ckpt.pos_min = 0;
|
||||
ckpt.pos_max = 0;
|
||||
ckpt.n_tokens = 0;
|
||||
ckpt.ckpt_size = 0;
|
||||
ckpt.data.clear();
|
||||
}
|
||||
}
|
||||
|
||||
size_t draft_init_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
|
||||
int slot_id = 0;
|
||||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
|
||||
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
|
||||
ckpt.data.resize(checkpoint_size);
|
||||
|
||||
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != checkpoint_size) {
|
||||
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
|
||||
}
|
||||
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
|
||||
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
|
||||
int slot_id = 0;
|
||||
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
|
||||
const size_t n = llama_state_seq_set_data_ext(ctx_dft,
|
||||
ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
if (n != ckpt_size_part_expected) {
|
||||
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
|
||||
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
|
||||
}
|
||||
return n;
|
||||
}
|
||||
|
||||
void draft(
|
||||
|
|
@ -236,8 +297,8 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
|
||||
auto * mem_dft = llama_get_memory(ctx_dft);
|
||||
|
||||
int reuse_i = 0;
|
||||
int reuse_n = 0;
|
||||
int reuse_i = 0; // index of part to be reused in prompt_dft
|
||||
int reuse_n = 0; // length of part to be reused in prompt_dft
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
|
||||
|
||||
|
|
@ -287,18 +348,26 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
}
|
||||
}
|
||||
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
|
||||
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
|
||||
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
|
||||
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
|
||||
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
|
||||
__func__, reuse_i, reuse_n);
|
||||
reuse_i = 0;
|
||||
reuse_n = 0;
|
||||
}
|
||||
|
||||
result.clear();
|
||||
result.reserve(params.n_max);
|
||||
|
||||
if (reuse_n == 0) {
|
||||
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
|
||||
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
|
||||
llama_memory_clear(mem_dft, false);
|
||||
prompt_dft.clear();
|
||||
} else {
|
||||
// this happens when a previous draft has been discarded (for example, due to being too small), but the
|
||||
// target model agreed with it. in this case, we simply pass back the previous results to save compute
|
||||
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
||||
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
|
||||
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
|
||||
result.push_back(prompt_dft[i]);
|
||||
|
||||
|
|
@ -310,19 +379,50 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
return;
|
||||
}
|
||||
|
||||
bool do_restore = false;
|
||||
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
|
||||
// This can happen after a partial acceptance (speculative decoding with checkpoints)
|
||||
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
|
||||
__func__, prompt_dft.size(), prompt_cur.size());
|
||||
prompt_dft.resize(prompt_cur.size());
|
||||
do_restore = true;
|
||||
}
|
||||
|
||||
if (reuse_i > 0) {
|
||||
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
|
||||
}
|
||||
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
|
||||
|
||||
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
|
||||
}
|
||||
|
||||
if (reuse_n < (int) prompt_dft.size()) {
|
||||
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||
if (reuse_n < (int) prompt_dft.size() || do_restore) {
|
||||
if (use_checkpoint) {
|
||||
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
|
||||
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%zu, reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
|
||||
}
|
||||
draft_restore_checkpoint(ckpt.ckpt_size);
|
||||
reuse_n = ckpt.n_tokens;
|
||||
prompt_dft.resize(reuse_n);
|
||||
needs_ckpt = false;
|
||||
} else {
|
||||
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
|
||||
if (!is_removed) {
|
||||
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
|
||||
__func__, reuse_n, prompt_dft.size());
|
||||
}
|
||||
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (needs_ckpt && use_checkpoint) {
|
||||
ckpt.ckpt_size = draft_init_checkpoint(prompt_dft.size(), batch.n_tokens);
|
||||
}
|
||||
|
||||
// prepare a batch to evaluate any new tokens in the prompt
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
|
@ -337,7 +437,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
if (batch.n_tokens > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
|
||||
llama_decode(ctx_dft, batch);
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size());
|
||||
}
|
||||
}
|
||||
|
||||
const llama_pos n_past = prompt_dft.size();
|
||||
|
|
@ -351,7 +455,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
|
||||
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
|
||||
|
||||
llama_decode(ctx_dft, batch);
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0 && ret != 1) {
|
||||
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||
__func__, ret, prompt_cur.size(), prompt_dft.size());
|
||||
}
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
|
|
@ -387,7 +495,11 @@ struct common_speculative_state_draft : public common_speculative_state {
|
|||
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
|
||||
|
||||
// evaluate the drafted tokens on the draft model
|
||||
llama_decode(ctx_dft, batch);
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
|
||||
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
|
||||
}
|
||||
|
||||
prompt_dft.push_back(id);
|
||||
}
|
||||
|
|
@ -909,9 +1021,10 @@ common_speculative * common_speculative_init(
|
|||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT: {
|
||||
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements
|
||||
/* .ctx_tgt = */ ctx_tgt,
|
||||
/* .ctx_dft = */ ctx_dft,
|
||||
/* .replacements = */ params.replacements,
|
||||
/* .use_checkpoint= */ params.use_checkpoints
|
||||
));
|
||||
break;
|
||||
}
|
||||
|
|
@ -1147,13 +1260,16 @@ struct common_speculative_session::impl {
|
|||
if (spec_ckpt_n_denials == 1) {
|
||||
// there is a previous speculation which wasn't accepted in full length
|
||||
if (draft.empty()) {
|
||||
LOG_WRN("%s: draft of length 0 after denied checkpoint\n", __func__);
|
||||
// switch to non-draft inference
|
||||
LOG_DBG("%s: draft of length 0 after denied checkpoint\n", __func__);
|
||||
clear_draft();
|
||||
return draft;
|
||||
}
|
||||
// we use the shortened draft of previous speculation
|
||||
LOG_DBG("%s: reuse shortened draft, #tokens=%zu, id_last=%d, size=%zu\n", __func__,
|
||||
cached_text_tokens.size(), id_last, draft.size());
|
||||
} else if (spec_ckpt_n_denials > 1) {
|
||||
GGML_ABORT("illegal state: spec_ckpt_n_denials = %d > 1", spec_ckpt_n_denials);
|
||||
} else {
|
||||
// call the speculative implementation to create a draft
|
||||
draft = common_speculative_draft(spec, params_spec, cached_text_tokens, id_last);
|
||||
|
|
@ -1224,6 +1340,7 @@ struct common_speculative_session::impl {
|
|||
draft.resize(ids.size() - 1);
|
||||
if (spec_has_ckpt) {
|
||||
// we need to rollback to the state before sampling the draft tokens
|
||||
// (restore_checkpoint shortens context and slot.prompt.tokens)
|
||||
const size_t n = callback.restore_checkpoint(spec_ckpt_size_part);
|
||||
LOG_DBG("%s: partial acceptance: %zu < %zu, restored checkpoint: got %zu bytes\n",
|
||||
__func__,
|
||||
|
|
@ -1236,7 +1353,8 @@ struct common_speculative_session::impl {
|
|||
spec_ckpt_n_denials++;
|
||||
if (ids.size() > 1u + static_cast<std::size_t>(params_spec.n_min) && spec_ckpt_n_denials == 1) {
|
||||
// we will do the batch again but with the shortened draft
|
||||
return common_speculative_accept_response(std::move(ids), n_draft, true);
|
||||
//return common_speculative_accept_response(std::move(ids), n_draft, true);
|
||||
LOG_DBG("%s: partial draft disabled\n", __func__);
|
||||
}
|
||||
|
||||
LOG_DBG("%s: don't accept partial draft, n_draft=%zu, ids.size=%zu\n", __func__, n_draft, ids.size());
|
||||
|
|
@ -1245,7 +1363,9 @@ struct common_speculative_session::impl {
|
|||
// use the sampled token only
|
||||
ids.resize(1);
|
||||
// drafted tokens in prompt have been deleted in restore_checkpoint(...).
|
||||
return common_speculative_accept_response{std::move(ids), 0, false};
|
||||
|
||||
// skip acceptance, don't calculate a new draft
|
||||
return common_speculative_accept_response{std::move(ids), 0, true};
|
||||
}
|
||||
}
|
||||
const size_t draft_size_accepted = draft.size();
|
||||
|
|
|
|||
|
|
@ -2184,7 +2184,8 @@ private:
|
|||
// compute draft and add draft to internal batch
|
||||
draft = slot.spec_session->compute_draft(cached_text_tokens, slot.sampled, n_draft_max_slot);
|
||||
if (draft.size() > 0) {
|
||||
SLT_DBG(slot, "compute_draft: #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n",
|
||||
SLT_DBG(slot, "compute_draft: id=%d, #cached_text_tokens=%zu, #tokens=%zu, #i_batch_dft=%zu\n",
|
||||
slot.sampled,
|
||||
cached_text_tokens.size(), draft.size(), slot.i_batch_dft.size());
|
||||
}
|
||||
}
|
||||
|
|
@ -2198,7 +2199,8 @@ private:
|
|||
|
||||
slot.prompt.tokens.push_back(slot.sampled);
|
||||
|
||||
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||
SLT_DBG(slot, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n",
|
||||
slot.sampled,
|
||||
slot.n_ctx, slot.prompt.n_tokens(), slot.truncated);
|
||||
}
|
||||
}
|
||||
|
|
@ -2954,6 +2956,7 @@ private:
|
|||
|
||||
// update how many tokens out of those tested were accepted
|
||||
slot.n_draft_accepted += ids.size() - 1;
|
||||
slot.n_draft_total += n_draft;
|
||||
|
||||
// rollback to the state before sampling the draft tokens
|
||||
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);
|
||||
|
|
@ -2961,6 +2964,7 @@ private:
|
|||
// add accepted tokens to the prompt
|
||||
slot.prompt.tokens.insert({ids.begin(), ids.end() - 1});
|
||||
slot.sampled = ids.back(); // last accepted token
|
||||
SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft);
|
||||
|
||||
slot.spec_session->rewind(slot.prompt.n_tokens());
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue