mtp-batch (fix): prevent mtp draft from polluting the cache
This commit is contained in:
parent
5e1d719bef
commit
6f74ba3807
|
|
@ -374,6 +374,8 @@ llama_token mtp_speculative_gen_draft(
|
|||
return -1;
|
||||
}
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
const llama_pos draft_pos = n_past;
|
||||
const llama_seq_id draft_seq_id = 0;
|
||||
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
|
||||
|
||||
mtp_batch.update_mtp_kv = false;
|
||||
|
|
@ -387,6 +389,8 @@ llama_token mtp_speculative_gen_draft(
|
|||
llama_decode(ctx, mtp_batch);
|
||||
llama_batch_free(mtp_batch);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx, draft_seq_id, draft_pos, draft_pos + 1);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const int n_vocab = llama_n_vocab(vocab);
|
||||
|
|
|
|||
|
|
@ -1460,9 +1460,13 @@ extern "C" {
|
|||
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
|
||||
|
||||
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
|
||||
|
||||
|
||||
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
|
||||
|
||||
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
|
||||
|
||||
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -3105,6 +3105,20 @@ void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float
|
|||
ctx->draft_input_hidden_state = hidden_state;
|
||||
}
|
||||
|
||||
bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
const auto & last_sinfo = kvd->last_main_model_sinfos;
|
||||
|
||||
if (last_sinfo.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
kvd->forced_sinfos = &last_sinfo;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
const auto & last_sinfo = kvd->last_main_model_sinfos;
|
||||
|
|
@ -3126,4 +3140,14 @@ bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_acc
|
|||
void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
kvd->forced_sinfos = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
if (memory) {
|
||||
static_cast<llama_kv_cache_unified *>(memory.get())->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
ctx->kv_cache_seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,6 +100,8 @@ struct llama_context {
|
|||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||
|
||||
// process a single ubatch with a specific graph type
|
||||
// if memory_context is provided, it will be applied first to the context's memory
|
||||
// ret contains the status of the graph computation
|
||||
|
|
|
|||
|
|
@ -3520,11 +3520,10 @@ struct server_context {
|
|||
needs_mtp_warmup = true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (needs_mtp_warmup) {
|
||||
if (llama_mtp_prepare_sinfo_for_update(ctx, batch_view.n_tokens)) {
|
||||
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
|
||||
mtp_update_kv_cache(ctx, batch_view, true);
|
||||
|
||||
llama_mtp_cancel_sinfo_update(ctx);
|
||||
} else {
|
||||
LOG_ERR("%s: Failed to prepare the MTP symphony for warmup.", __func__);
|
||||
|
|
|
|||
Loading…
Reference in New Issue