mtp-batch (wip): Isolate MTP graph to prevent host embedding buffer corruption
parent 75dc25e6fe
commit 67c6c069e0
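The hunks below span the speculative helpers, the public llama_batch struct, llama_context, the GLM4-MoE graph builder, and the server loop. The common thread is that MTP work (draft generation and KV warm-up/updates) is now routed through explicit batch flags so the MTP graph never writes back into the host embedding buffer that the next draft reads. As a rough orientation, the three decode modes implied by this commit look like the sketch below; it relies on the llama_batch fields added in this diff and is not itself part of the patch.

```cpp
// Sketch only: relies on the update_mtp_kv / use_mtp_head / is_mtp_prompt_warmup
// fields that this commit adds to llama_batch.
#include "llama.h"

// 1) Regular decode with the main head: all MTP flags stay false.
static int decode_main(llama_context * ctx, llama_batch batch) {
    batch.update_mtp_kv        = false;
    batch.use_mtp_head         = false;
    batch.is_mtp_prompt_warmup = false;
    return llama_decode(ctx, batch);
}

// 2) Draft generation with the MTP head: reads the stashed hidden state,
//    leaves the host embedding buffer untouched.
static int decode_mtp_draft(llama_context * ctx, llama_batch batch) {
    batch.update_mtp_kv        = false;
    batch.use_mtp_head         = true;
    batch.is_mtp_prompt_warmup = false;
    return llama_decode(ctx, batch);
}

// 3) MTP KV update (prompt warm-up or accepted tokens): reads the freshly
//    written main embedding buffer.
static int decode_mtp_kv_update(llama_context * ctx, llama_batch batch, bool prompt_warmup) {
    batch.update_mtp_kv        = true;
    batch.use_mtp_head         = true;
    batch.is_mtp_prompt_warmup = prompt_warmup;
    return llama_decode(ctx, batch);
}
```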
@@ -373,19 +373,9 @@ llama_token mtp_speculative_gen_draft(
if (!smpl) {
return -1;
}
const float * draft_input_hidden_state = llama_get_embeddings(ctx);
llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);
LOG_INF("[DEBUG-DRAFT-STATE] Main model final embd pointer: %p, State being used for draft: %p\n",
(void*)llama_get_embeddings(ctx), (void*)draft_input_hidden_state);

llama_batch mtp_batch = llama_batch_init(1, 0, 1);
common_batch_add(mtp_batch, id_last, n_past, {0}, true);

LOG_INF(
"[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n",
id_last, n_past, last_tok_idx
);

mtp_batch.update_mtp_kv = false;
mtp_batch.use_mtp_head = true;
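For context, the later server hunks pair this function with llama_set_draft_input_hidden_state: the host stashes the hidden state to read before asking for a draft. A condensed sketch of that caller-side contract follows; the function and type names come from this patch (not upstream llama.cpp), and the include path is an assumption.

```cpp
// Caller-side sketch: stash the hidden state of the last accepted token,
// then ask the MTP head for a one-token draft. Assumes this patch's API;
// "speculative.h" stands in for the common/ header declaring these helpers.
#include "llama.h"
#include "speculative.h"

static llama_token draft_one_token(struct common_sampler * smpl, llama_context * ctx,
                                   llama_token id_last, int32_t n_past, int32_t last_tok_idx) {
    // point the draft path at the output row the main model just produced
    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, last_tok_idx));

    // internally builds a 1-token batch with update_mtp_kv=false, use_mtp_head=true
    return mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
}
```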
@@ -413,7 +403,9 @@ llama_token mtp_speculative_gen_draft(
}


void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, const char* tag) {
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
bool is_prompt_warmup) {

if (tokens.empty()) {
return;
}

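The const char * tag parameter becomes bool is_prompt_warmup; the log label is now derived from the boolean instead of being passed in. A standalone sketch of that mapping (plain C++, not llama.cpp code):

```cpp
// The two call sites in this commit pass true (prompt warm-up) or false
// (accepted-token update); the old string tags are recovered like this.
#include <cstdio>

static const char * mtp_update_label(bool is_prompt_warmup) {
    return is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED";
}

int main() {
    std::printf("[MTP-UPDATE|%s]\n", mtp_update_label(true));   // was: mtp_update_kv_cache(..., "PROMPT_WARMUP")
    std::printf("[MTP-UPDATE|%s]\n", mtp_update_label(false));  // was: mtp_update_kv_cache(..., "GEN_ACCEPTED")
    return 0;
}
```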
@@ -423,26 +415,34 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
details_str += " {id: " + std::to_string(tokens[i].id) + ", pos: " + std::to_string(tokens[i].n_past) + "}";
}
LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", tag, n_to_process, details_str.c_str());
LOG_INF("[MTP-UPDATE|%s] Updating %zu tokens. Details:%s ...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", n_to_process, details_str.c_str());

// LOG_INF("[DEBUG-CHUNK] Warming up MTP model chunk. Batch size: %zu\n", n_to_process);
// std::string positions_str;
// for (size_t i = 0; i < std::min((size_t)5, n_to_process); ++i) {
// positions_str += std::to_string(tokens[i].n_past) + " ";
// }
// LOG_INF("[DEBUG-CHUNK] MTP warm-up positions: %s...\n", positions_str.c_str());
llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);

for (size_t i = 0; i < n_to_process; ++i) {
const mtp_kv_update_data& token_data = tokens[i];
// Check seq_id {0}; it may be a problem with multiple sequences.
common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
}

mtp_batch.update_mtp_kv = true;
mtp_batch.use_mtp_head = true;
mtp_batch.is_mtp_prompt_warmup = is_prompt_warmup;

llama_decode(ctx, mtp_batch);

llama_batch_free(mtp_batch);
tokens.clear();
}

// Debug function - it will be removed later
double calculate_vector_sum_double(const float* vec, size_t size) {
if (!vec) {
return 0.0;
}
double sum = 0.0;
for (size_t i = 0; i < size; ++i) {
sum += vec[i];
}
return sum;
}

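calculate_vector_sum_double exists purely to checksum a host buffer before and after an operation that must not touch it, which is how the server hunks further down use it. A self-contained illustration of that pattern (the buffer here is a stand-in, not the real embedding buffer):

```cpp
// Snapshot a buffer's sum, run an operation that must not modify it,
// then compare. Mirrors the [VERIFY] checks added later in this diff.
#include <cstdio>
#include <vector>

static double calculate_vector_sum_double(const float * vec, size_t size) {
    if (!vec) return 0.0;
    double sum = 0.0;
    for (size_t i = 0; i < size; ++i) sum += vec[i];
    return sum;
}

int main() {
    std::vector<float> embd(4096, 0.25f);  // stand-in for the host embedding buffer
    const double before = calculate_vector_sum_double(embd.data(), embd.size());

    // ... run the operation under test here (e.g. an MTP KV update) ...

    const double after = calculate_vector_sum_double(embd.data(), embd.size());
    std::printf("checksum before=%e after=%e %s\n", before, after,
                before == after ? "(unchanged)" : "(CORRUPTED)");
    return 0;
}
```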
@@ -49,4 +49,7 @@ llama_tokens common_speculative_gen_draft(
const llama_tokens & prompt,
llama_token id_last);

void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, const char* tag);
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens,
bool is_prompt_warmup);

double calculate_vector_sum_double(const float* vec, size_t size);
@@ -232,6 +232,7 @@ extern "C" {
int8_t * logits; // TODO: rename this to "output"
bool update_mtp_kv;
bool use_mtp_head;
bool is_mtp_prompt_warmup;
} llama_batch;

enum llama_model_kv_override_type {
@@ -843,6 +843,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*.logits               =*/ nullptr,
/*.update_mtp_kv        =*/ false,
/*.use_mtp_head         =*/ false,
/*.is_mtp_prompt_warmup =*/ false,
};

if (embd) {
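One detail worth keeping in mind here: llama_batch is filled by positional aggregate initialization, so the /*.name =*/ comments are documentation only and have to follow the member order declared in llama.h above (update_mtp_kv, then use_mtp_head, then is_mtp_prompt_warmup). A standalone sketch of why the order matters; the struct here is illustrative, not the real llama_batch:

```cpp
// In aggregate initialization the values are assigned by position; the
// /*.name =*/ comments are documentation only, so they must follow the
// declaration order to stay truthful.
#include <cstdio>

struct batch_flags {
    bool update_mtp_kv;
    bool use_mtp_head;
    bool is_mtp_prompt_warmup;
};

int main() {
    batch_flags f = {
        /*.update_mtp_kv        =*/ false,
        /*.use_mtp_head         =*/ true,
        /*.is_mtp_prompt_warmup =*/ false,
    };
    std::printf("update_mtp_kv=%d use_mtp_head=%d warmup=%d\n",
                f.update_mtp_kv, f.use_mtp_head, f.is_mtp_prompt_warmup);
    return 0;
}
```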
@@ -13,6 +13,7 @@
#include <cstring>
#include <limits>
#include <stdexcept>
#include <numeric>

//
// llama_context
@@ -729,8 +730,19 @@ bool llama_context::apply_adapter_cvec(
return cvec.apply(model, data, len, n_embd, il_start, il_end);
}

static double calculate_vector_sum(const float* vec, size_t size) {
if (!vec) {
return 0.0;
}
double sum = 0.0;
for (size_t i = 0; i < size; ++i) {
sum += vec[i];
}
return sum;
}

llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
bool do_mtp_kv_update, bool use_mtp_head) {
bool do_mtp_kv_update, bool use_mtp_head, bool is_mtp_prompt_warmup) {
if (mctx && !mctx->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
ret = GGML_STATUS_FAILED;
@@ -778,9 +790,20 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);

const float * source_hidden_state = nullptr;
source_hidden_state = this->draft_input_hidden_state;
if (is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)) {
source_hidden_state = this->embd;
} else {
source_hidden_state = this->draft_input_hidden_state;
}

if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
const size_t n_embd = this->model.hparams.n_embd;
const size_t n_tokens_for_sum = (do_mtp_kv_update && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
const char * op_type = (do_mtp_kv_update) ? "MTP_UPDATE" : "DRAFT_GEN";

LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);

ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
} else {
LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
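The new branch picks which host buffer feeds the MTP graph's hidden-state input: KV updates (prompt warm-up or accepted-token updates) read the freshly written ctx->embd, while draft generation reads the pointer stashed via llama_set_draft_input_hidden_state. The same decision factored into a pure function for clarity, noting that the condition in the hunk reduces to is_mtp_prompt_warmup || do_mtp_kv_update; this is a sketch, not part of the diff:

```cpp
// Decision logic from this hunk as a pure function (sketch).
#include <cassert>

static const float * select_mtp_source(bool do_mtp_kv_update,
                                        bool is_mtp_prompt_warmup,
                                        const float * main_embd,
                                        const float * draft_input_hidden_state) {
    // equivalent to: is_mtp_prompt_warmup || (do_mtp_kv_update && !is_mtp_prompt_warmup)
    if (is_mtp_prompt_warmup || do_mtp_kv_update) {
        return main_embd;
    }
    return draft_input_hidden_state;
}

int main() {
    float main_buf[1] = {0}, draft_buf[1] = {0};
    assert(select_mtp_source(true,  false, main_buf, draft_buf) == main_buf);  // GEN_ACCEPTED update
    assert(select_mtp_source(true,  true,  main_buf, draft_buf) == main_buf);  // prompt warm-up
    assert(select_mtp_source(false, false, main_buf, draft_buf) == draft_buf); // draft generation
    return 0;
}
```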
@@ -881,7 +904,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
cparams.causal_attn = false;

ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false, false);

cparams.causal_attn = causal_attn_org;
@@ -1107,7 +1130,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
int64_t n_outputs_prev = 0;
const bool do_mtp_kv_update = batch_inp.update_mtp_kv;
const bool use_mtp_head = batch_inp.use_mtp_head;
const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0);
const bool is_prompt_warmup = batch_inp.is_mtp_prompt_warmup;

do {
const auto & ubatch = mctx->get_ubatch();
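Previously is_prompt_warmup was inferred from n_tokens > 1 && nextn_predict_layers > 0, which also matches a multi-token KV update for accepted draft tokens during generation; the explicit is_mtp_prompt_warmup flag removes that ambiguity. A small standalone contrast (the values are illustrative):

```cpp
// The removed heuristic vs. the explicit flag carried on the batch.
// The heuristic also fires for multi-token GEN_ACCEPTED updates, which is
// the ambiguity this hunk removes.
#include <cstdio>
#include <cstdint>

static bool warmup_heuristic(uint32_t n_tokens, uint32_t nextn_predict_layers) {
    return n_tokens > 1 && nextn_predict_layers > 0;
}

int main() {
    // A 3-token accepted-draft KV update on an MTP model:
    const bool by_heuristic = warmup_heuristic(/*n_tokens=*/3, /*nextn_predict_layers=*/1);
    const bool by_flag      = false; // batch.is_mtp_prompt_warmup set by the caller
    std::printf("heuristic says warm-up: %d, explicit flag says: %d\n", by_heuristic, by_flag);
    return 0;
}
```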
@@ -1117,7 +1140,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
pos_str += std::to_string(ubatch.pos[i]) + " ";
}
LLAMA_LOG_WARN(
"[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Posições: %s...\n",
"[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
ubatch.n_tokens,
batch_inp.update_mtp_kv ? "true" : "false",
batch_inp.use_mtp_head ? "true" : "false",
@@ -1149,7 +1172,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
LLAMA_LOG_WARN("[DEBUG-MTP-UPDATE] Positions: %s...\n", positions_str.c_str());
}
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head, is_prompt_warmup);
if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
llama_pos pos_min[LLAMA_MAX_SEQ];
@@ -1186,20 +1209,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}

// if (is_prompt_warmup) {
// auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
// ggml_status status_mtp;

// process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head);

// if (status_mtp != GGML_STATUS_SUCCESS) {
// LLAMA_LOG_WARN("%s: Failure in MTP warm-up ubatch\n", __func__);
// }
// }

auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
embd_tensor = res->get_embd();

if (t_embd && res->get_embd_pooled()) {
t_embd = res->get_embd_pooled();
@@ -1220,58 +1231,69 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}

if (use_mtp_head) {
if (t_embd != nullptr) {
LLAMA_LOG_ERROR("[MTP-GRAPH-BUG] The MTP graph returned an embedding tensor when it shouldn't have! This will cause corruption.\n");
} else {
LLAMA_LOG_WARN("[MTP-GRAPH-OK] The MTP graph correctly did not return an embedding tensor.\n");
}
}

// extract embeddings
if (t_embd && n_outputs > 0) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);
if (!use_mtp_head) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);

switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;
switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
float * embd_out = embd + n_outputs_prev*n_embd;

if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
if (n_outputs) {
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
} break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_CLS:
case LLAMA_POOLING_TYPE_LAST:
{
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_embd);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;

const uint32_t n_cls_out = hparams.n_cls_out;

for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
const int32_t seq_idx = ubatch.seq_idx[seq_id];

embd_seq_out[seq_id].resize(n_cls_out);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
}
}
} else {
LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n");
}
}

@@ -1336,8 +1358,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
// overlap with device computation.
ggml_backend_sched_reset(sched.get());
}
if (!do_mtp_kv_update && !use_mtp_head) {
LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd);

if (!use_mtp_head) {
synchronize();
const size_t n_embd = this->model.hparams.n_embd;
double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);
LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum);
}
return 0;
}
@@ -103,7 +103,8 @@ struct llama_context {
llama_memory_context_i * mctx,
ggml_status & ret,
const bool do_mtp_kv_update,
const bool use_mtp_head);
const bool use_mtp_head,
bool is_mtp_prompt_warmup);

int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);
@@ -13821,20 +13821,19 @@ struct llm_build_glm4_moe : public llm_graph_context {
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);

auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
}
res->t_embd = hidden_states_from_main_model;
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
}

const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);

} else {
ggml_tensor * inpL = build_inp_embd(model.tok_embd);
@@ -13991,9 +13990,11 @@ private:
ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
int64_t n_embd_head
) {
ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);

const int il = hparams.n_layer - 1;
// LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il);
ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings);
ggml_tensor * sum_node = ggml_sum(ctx0, embd_copy);

ggml_set_name(sum_node, "mtp_input_sum");
@@ -14002,7 +14003,7 @@ private:
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);

ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);

ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
ggml_tensor * cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
@@ -18694,13 +18695,15 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
GGML_ABORT("fatal error");
}

// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
const int64_t t_end_us = ggml_time_us(); // end of timer
if (!params.use_mtp_head) {
// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
}
const int64_t t_end_us = ggml_time_us();
LLAMA_LOG_INFO(
"[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
(t_end_us - t_start_us) / 1000.0,
build_mtp ? "yes" : "no"
params.use_mtp_head ? "yes" : "no"
);
return llm->res->get_gf();
}
@@ -3522,8 +3522,6 @@ struct server_context {
}

// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
// Warm up the MTP cache for the prompt chunks that have just been processed.
// This logic must ONLY run during prompt processing.
for (auto & slot : slots) {
if (slot.state == SLOT_STATE_PROCESSING_PROMPT && slot.has_mtp && !slot.mtp_kv_update_batch.empty()) {
SLT_INF(slot, "DEBUG-KV-REQ: Warming up MTP cache for prompt chunk of size %zu. Positions: %d ... %d\n",
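The warm-up gate has three conditions: the slot is still processing its prompt, the model has an MTP head, and there is a pending update batch. A standalone sketch of the predicate; the slot_stub type and enum here are illustrative stand-ins for the server's slot, not its real definitions:

```cpp
// Illustrative predicate for the warm-up gate above; field names mirror the
// server code, the types are stand-ins.
#include <cstdio>
#include <vector>

enum slot_state { SLOT_STATE_IDLE, SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_GENERATING };

struct mtp_kv_update_data_stub { int id; int n_past; int idx; };

struct slot_stub {
    slot_state state = SLOT_STATE_IDLE;
    bool has_mtp = false;
    std::vector<mtp_kv_update_data_stub> mtp_kv_update_batch;
};

static bool should_warm_up_mtp(const slot_stub & slot) {
    return slot.state == SLOT_STATE_PROCESSING_PROMPT
        && slot.has_mtp
        && !slot.mtp_kv_update_batch.empty();
}

int main() {
    slot_stub s;
    s.state   = SLOT_STATE_PROCESSING_PROMPT;
    s.has_mtp = true;
    s.mtp_kv_update_batch.push_back({1, 0, 0});
    std::printf("warm up now: %d\n", should_warm_up_mtp(s)); // 1: prompt chunk pending

    s.state = SLOT_STATE_GENERATING;                         // never during token generation
    std::printf("warm up now: %d\n", should_warm_up_mtp(s)); // 0
    return 0;
}
```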
@@ -3531,7 +3529,7 @@ struct server_context {
slot.mtp_kv_update_batch.front().n_past,
slot.mtp_kv_update_batch.back().n_past
);
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "PROMPT_WARMUP");
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, true);
}
}
@@ -3569,13 +3567,16 @@ struct server_context {
}

const int tok_idx = slot.i_batch - i;

// Sets the initial state for the first draft generation.
if (slot.has_mtp) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));
}
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
slot.last_tok_idx = tok_idx;
//SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());

slot.i_batch = -1;

SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i);
common_sampler_accept(slot.smpl, id, true);

slot.n_decoded += 1;
@@ -3647,6 +3648,7 @@ struct server_context {

llama_tokens draft;
if (slot.has_mtp) {
SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past);
llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
draft.reserve(1);
draft.push_back(draft_id);
@@ -3682,21 +3684,39 @@ struct server_context {
}

SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);

SLT_INF(slot, "[POS-SYNC] Before validation decode. n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens);
llama_decode(ctx, slot.batch_spec);

const size_t n_embd = llama_n_embd(llama_get_model(ctx));
const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd;
const float* golden_embd_ptr = llama_get_embeddings(ctx);
double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %d tokens)\n", golden_checksum, slot.batch_spec.n_tokens);

// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size());

if (slot.has_mtp) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));

const float* embd_after_draft_ptr = llama_get_embeddings(ctx);
double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);

slot.mtp_kv_update_batch.clear();
for (int32_t i = 0; i < ids.size(); ++i) {
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + i, i });
}
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, "GEN_ACCEPTED");
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, false);

const float* embd_after_update_ptr = llama_get_embeddings(ctx);
double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);
}

slot.n_past += ids.size();
SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past);
slot.n_decoded += ids.size();

// update how many tokens out of those tested were accepted
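After validation, each accepted token is queued for the MTP KV update at position n_past + i, and only then is n_past advanced by the number of accepted tokens. A standalone sketch of that bookkeeping; the token ids and the starting position are made up:

```cpp
// Accepted-token bookkeeping from the hunk above: queue each accepted id at
// position n_past + i, then advance n_past.
#include <cstdio>
#include <vector>

struct mtp_kv_update_entry { int id; int n_past; int idx; };

int main() {
    int n_past = 42;
    const std::vector<int> ids = {101, 102};  // accepted tokens (target + drafted)

    std::vector<mtp_kv_update_entry> update_batch;
    for (int i = 0; i < (int) ids.size(); ++i) {
        update_batch.push_back({ ids[i], n_past + i, i });
    }
    n_past += (int) ids.size();

    for (const auto & e : update_batch) {
        std::printf("token %d -> pos %d (idx %d)\n", e.id, e.n_past, e.idx);
    }
    std::printf("new n_past = %d\n", n_past);
    return 0;
}
```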