mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption

samuel 2025-09-27 19:42:32 -03:00
parent 75dc25e6fe
commit 67c6c069e0
8 changed files with 168 additions and 113 deletions
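
In short: MTP passes (batches with use_mtp_head = true) no longer copy their output tensor back into the host embedding buffer, so the hidden state produced by the main model survives intact for draft generation. A new is_mtp_prompt_warmup flag on llama_batch lets process_ubatch pick the hidden-state source explicitly: the live ctx->embd buffer for prompt warm-up and MTP KV-cache updates, or the snapshot installed via llama_set_draft_input_hidden_state for draft generation. A condensed sketch of the caller-side pattern, using only identifiers that appear in this diff (illustrative, not a drop-in snippet):

    // Warm the MTP KV cache for a just-processed prompt chunk.
    llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
    for (size_t i = 0; i < n_to_process; ++i) {
        common_batch_add(mtp_batch, tokens[i].id, tokens[i].n_past, {0}, false);
    }
    mtp_batch.update_mtp_kv        = true;  // append these positions to the MTP layer's KV cache
    mtp_batch.use_mtp_head         = true;  // run only the MTP head graph
    mtp_batch.is_mtp_prompt_warmup = true;  // read hidden states from ctx->embd
    llama_decode(ctx, mtp_batch);           // decode() skips the embedding copy-back when use_mtp_head is set
    llama_batch_free(mtp_batch);

    // During generation: snapshot the accepted token's hidden state before drafting,
    // then refresh the MTP KV cache with is_prompt_warmup = false.
    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, /*is_prompt_warmup=*/ false);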

View File

@ -373,19 +373,9 @@ llama_token mtp_speculative_gen_draft(
if (!smpl) {
return -1;
}
const float * draft_input_hidden_state = llama_get_embeddings(ctx);
llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);
LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n",
(void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state);
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
common_batch_add(mtp_batch, id_last, n_past, {0}, true);
LOG_INF(
"[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n",
id_last, n_past, draft_input_hidden_state
);
mtp_batch.update_mtp_kv = false;
mtp_batch.use_mtp_head = true;
@ -413,7 +403,9 @@ llama_token mtp_speculative_gen_draft(
}
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, const char* tag) {
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
bool is_prompt_warmup) {
if (tokens.empty()) {
return;
}
@ -423,26 +415,34 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}";
}
LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", tag, n_to_process, details_str.c_str());
LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", n_to_process, details_str.c_str());
// LOG_INF("[DEBUG-CHUNK] Warming up MTP model chunk. Batch size: %zu\n", n_to_process);
// std::string positions_str;
// for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
// positions_str += std::to_string(tokens[i].n_past) + " ";
// }
// LOG_INF("[DEBUG-CHUNK] MTP warm-up positions: %s...\n", positions_str.c_str());
llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);
for (size_t i = 0; i < n_to_process; ++i) {
const mtp_kv_update_data& token_data = tokens[i];
// Check seq_id {0}, it may be a problem with multiple sequences.
common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
}
mtp_batch.update_mtp_kv = true;
mtp_batch.use_mtp_head = true;
mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup;
llama_decode(ctx, mtp_batch);
llama_batch_free(mtp_batch);
tokens.clear();
}
// Debug function - It will be removed later
double calculate_vector_sum_double(const float* vec, size_t size) {
if (!vec) {
return 0.0;
}
double sum = 0.0;
for (size_t i = 0; i < size; ++i) {
sum += vec[i];
}
return sum;
}

View File

@ -49,4 +49,7 @@ llama_tokens common_speculative_gen_draft(
const llama_tokens & prompt,
llama_token id_last);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, const char* tag);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
bool is_prompt_warmup);
double calculate_vector_sum_double(const float* vec, size_t size);

View File

@ -232,6 +232,7 @@ extern "C" {
int8_t * logits; // TODO: rename this to "output"
bool update_mtp_kv;
bool use_mtp_head;
bool is_mtp_prompt_warmup;
} llama_batch;
enum llama_model_kv_override_type {

View File

@ -843,6 +843,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*logits =*/ nullptr,
/*.use_mtp_head =*/ false,
/*update_mtp_kv =*/ false,
/*.is_mtp_prompt_warmup =*/ false,
};
if (embd) {

View File

@ -13,6 +13,7 @@
#include <cstring>
#include <limits>
#include <stdexcept>
#include <numeric>
//
// llama_context
@ -729,8 +730,19 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}
static double calculate_vector_sum(const float* vec, size_t size) {
if (!vec) {
return 0.0;
}
double sum = 0.0;
for (size_t i = 0; i < size; ++i) {
sum += vec[i];
}
return sum;
}
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
bool do_mtp_kv_update, bool use_mtp_head) {
bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
@ -778,9 +790,20 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
const float * source_hidden_state = nullptr;
source_hidden_state = this->draft_input_hidden_state;
if (is_mtp_prompt_warmup || do_mtp_kv_update) {
source_hidden_state = this->embd;
} else {
source_hidden_state = this->draft_input_hidden_state;
}
if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
const size_t n_embd = this->model.hparams.n_embd;
const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
const char * op_type = (do_mtp_kv_update) ? "MTP_UPDATE" : "DRAFT_GEN";
LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
} else {
LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
@ -881,7 +904,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
cparams.causal_attn = false;
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false);
cparams.causal_attn = causal_attn_org;
@ -1107,7 +1130,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;
const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
const bool use_mtp_head = batch_inp.use_mtp_head;
const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0);
const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;
do {
const auto & ubatch = mctx->get_ubatch();
@ -1117,7 +1140,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_str += std::to_string(ubatch.pos[i]) + " ";
}
LLAMA_LOG_WARN(
"[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Posições: %s...\n",
"[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
ubatch.n_tokens,
batch_inp.update_mtp_kv ? "true" : "false",
batch_inp.use_mtp_head ? "true" : "false",
@ -1149,7 +1172,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
}
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_SEQ];
@ -1186,20 +1209,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
// if (is_prompt_warmup) {
// auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
// ggml_status status_mtp;
// process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head);
// if (status_mtp != GGML_STATUS_SUCCESS) {
// LLAMA_LOG_WARN("%s: Failure in MTP heating ubatch\n", __func__);
// }
// }
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
embd_tensor = res->get_embd();
if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@ -1220,58 +1231,69 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
if (use_mtp_head) {
if (t_embd != nullptr) {
LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
} else {
LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
}
}
// extract embeddings
if (t_embd && n_outputs > 0) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
if (!use_mtp_head) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
} else {
LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n");
}
}
@ -1336,8 +1358,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
if (!do_mtp_kv_update && !use_mtp_head) {
LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd);
if (!use_mtp_head) {
synchronize();
const size_t n_embd = this->model.hparams.n_embd;
double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);
LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum);
}
return 0;
}

View File

@ -103,7 +103,8 @@ struct llama_context {
llama_memory_context_i * mctx,
ggml_status & ret,
const bool do_mtp_kv_update,
const bool use_mtp_head);
const bool use_mtp_head,
bool is_mtp_prompt_warmup);
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);

View File

@ -13821,20 +13821,19 @@ struct llm_build_glm4_moe : public llm_graph_context {
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
}
res->t_embd = hidden_states_from_main_model;
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
}
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);
} else {
ggml_tensor * inpL = build_inp_embd(model.tok_embd);
@ -13991,9 +13990,11 @@ private:
ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
int64_t n_embd_head
) {
ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);
const int il = hparams.n_layer - 1;
// LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il);
ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings);
ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy);
ggml_set_name(sum_node, "mtp_input_sum");
@ -14002,7 +14003,7 @@ private:
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
@ -18694,13 +18695,15 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
GGML_ABORT("fatal error");
}
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
const int64_t t_end_us = ggml_time_us(); // End of the timer
if (!params.use_mtp_head) {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
}
const int64_t t_end_us = ggml_time_us();
LLAMA_LOG_INFO(
"[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
(t_end_us - t_start_us) / 1000.0,
build_mtp ? "yes" : "no"
params.use_mtp_head ? "yes" : "no"
);
return llm->res->get_gf();
}

View File

@ -3522,8 +3522,6 @@ struct server_context {
}
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
// Warm up the MTP cache for the prompt chunks that have just been processed.
// This logic should ONLY run during prompt processing.
for (auto & slot : slots) {
if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) {
SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n",
@ -3531,7 +3529,7 @@ struct server_context {
slot.mtp_kv_update_batch.front().n_past,
slot.mtp_kv_update_batch.back().n_past
);
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP");
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, true);
}
}
@ -3569,13 +3567,16 @@ struct server_context {
}
const int tok_idx = slot.i_batch - i;
// Sets the initial state for the first draft generation.
if (slot.has_mtp) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
}
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
slot.last_tok_idx = tok_idx;
//SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
slot.i_batch = -1;
SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i);
common_sampler_accept(slot.smpl, id, true);
slot.n_decoded += 1;
@ -3647,6 +3648,7 @@ struct server_context {
llama_tokens draft;
if (slot.has_mtp) {
SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past);
llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
draft.reserve(1);
draft.push_back(draft_id);
@ -3682,21 +3684,39 @@ struct server_context {
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
SLT_INF(slot, "[POS-SYNC] Before validation decode. n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens);
llama_decode(ctx, slot.batch_spec);
const size_t n_embd = llama_n_embd(llama_get_model(ctx));
const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd;
const float* golden_embd_ptr = llama_get_embeddings(ctx);
double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %zu tokens)\n", golden_checksum, slot.batch_spec.n_tokens);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size());
if (slot.has_mtp) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
const float* embd_after_draft_ptr = llama_get_embeddings(ctx);
double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);
slot.mtp_kv_update_batch.clear();
for (int32_t i = 0; i < ids.size(); ++i) {
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
}
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "GEN_ACCEPTED");
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, false);
const float* embd_after_update_ptr = llama_get_embeddings(ctx);
double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);
}
slot.n_past += ids.size();
SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past);
slot.n_decoded += ids.size();
// update how many tokens out of those tested were accepted