mtp-batch(fix): avoid logits for mtp kv cache operations

This commit is contained in:
samuel 2025-10-16 13:42:31 -03:00
parent 0127c6beeb
commit cae85fe531
1 changed files with 17 additions and 8 deletions

View File

@ -1155,6 +1155,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract logits
if (t_logits && n_outputs > 0) {
// MTP operations that are purely for updating the KV cache
// (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor
// as a side effect of running the graph. If these logits are copied
// back to the main context buffer, they will overwrite the valid logits
// produced by the main model's pass, leading to incorrect sampling.
// This condition explicitly prevents that copy for cache-only operations.
if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP &&
batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
GGML_ASSERT(backend_res != nullptr);
GGML_ASSERT(logits != nullptr);
@ -1167,6 +1175,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
}
}
}
// extract embeddings
if (t_embd && n_outputs > 0) {