From cae85fe531876762ee02524fc4c3f6c5e7824c63 Mon Sep 17 00:00:00 2001 From: samuel Date: Thu, 16 Oct 2025 13:42:31 -0300 Subject: [PATCH] mtp-batch(fix): avoid logits for mtp kv cache operations --- src/llama-context.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a5345ee2a4..fb35d6c79d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1155,16 +1155,25 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract logits if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + // MTP operations that are purely for updating the KV cache + // (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor + // as a side effect of running the graph. If these logits are copied + // back to the main context buffer, they will overwrite the valid logits + // produced by the main model's pass, leading to incorrect sampling. + // This condition explicitly prevents that copy for cache-only operations. + if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP && + batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits + n_outputs_prev*n_vocab; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } } }