mtp-batch(chore): Fix logit flags for speculative sampling and remove debug logs
parent a99709d0c1
commit b4cbe030ac
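Background (not part of the diff, a hedged sketch): in llama.cpp, every token in a llama_batch carries a per-position logits flag, and llama_decode() only produces an output row for positions where that flag is set. The hunks below flip that flag from false to true for the batches built in mtp_update_kv_cache() and mtp_accept_tokens(). The general pattern looks roughly like the following; draft_ids and n_draft are placeholder names, and ctx, n_past, seq_id are assumed to come from the surrounding server code (requires llama.h and the repo's common.h):

    // build a batch of drafted tokens for validation
    llama_batch batch = llama_batch_init(/*n_tokens*/ n_draft, /*embd*/ 0, /*n_seq_max*/ 1);
    for (int32_t i = 0; i < n_draft; ++i) {
        // last argument is the per-token logits flag: request an output for every
        // drafted position so the sampler can score it during validation
        common_batch_add(batch, draft_ids[i], n_past + i, { seq_id }, /*logits*/ true);
    }
    llama_decode(ctx, batch);
    llama_batch_free(batch);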
@@ -380,11 +380,6 @@ llama_token mtp_speculative_gen_draft(
     mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
 
-    // LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n",
-    // mtp_batch.update_mtp_kv ? "true" : "false",
-    // mtp_batch.use_mtp_head ? "true" : "false"
-    // );
-
     // Perform the MTP draft generation decode. This writes the MTP layer's
     // KV state for the draft token into the cache.
     llama_decode(ctx, mtp_batch);
@@ -416,7 +411,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
         return;
     }
 
-    LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
 
     llama_batch mtp_batch = batch;
     if (is_prompt_warmup) {
@@ -426,7 +421,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     }
 
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
-        mtp_batch.logits[i] = false;
+        mtp_batch.logits[i] = true;
     }
     llama_decode(ctx, mtp_batch);
 }
@@ -447,7 +442,7 @@ void mtp_accept_tokens(
 
     llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
     for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false);
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
     }
 
     mtp_update_kv_cache(ctx, accepted_batch, false);
@@ -456,15 +451,3 @@ void mtp_accept_tokens(
 
     llama_batch_free(accepted_batch);
 }
-
-// Debug function - It will be removed later
-double calculate_vector_sum_double(const float* vec, size_t size) {
-    if (!vec) {
-        return 0.0;
-    }
-    double sum = 0.0;
-    for (size_t i = 0; i < size; ++i) {
-        sum += vec[i];
-    }
-    return sum;
-}
@@ -57,5 +57,3 @@ void mtp_accept_tokens(
     int32_t n_past_base,
     llama_seq_id seq_id
 );
-
-double calculate_vector_sum_double(const float* vec, size_t size);
@@ -809,15 +809,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
     }
 
-    const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
-    const int64_t t_exec_end_us = ggml_time_us();
-    // LLAMA_LOG_INFO(
-    // "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
-    // (t_exec_end_us - t_exec_start_us) / 1000.0,
-    // ubatch.n_tokens,
-    // do_mtp_kv_update ? "yes" : "no"
-    // );
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -827,9 +819,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     ret = GGML_STATUS_SUCCESS;
     if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
         ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
-        if (sum_tensor) {
-            LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n");
-        }
     }
     return res;
 }
@@ -1123,20 +1112,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     do {
         const auto & ubatch = mctx->get_ubatch();
-        if (ubatch.n_tokens > 0) {
-            std::string pos_str;
-            for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
-                pos_str += std::to_string(ubatch.pos[i]) + " ";
-            }
-            // LLAMA_LOG_WARN(
-            // "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
-            // ubatch.n_tokens,
-            // batch_inp.update_mtp_kv ? "true" : "false",
-            // batch_inp.use_mtp_head ? "true" : "false",
-            // pos_str.c_str()
-            // );
-        }
-
         // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
@@ -1281,8 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     GGML_ABORT("unknown pooling type");
                 }
             }
-        } else {
-            LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n");
         }
     }
 
@@ -1347,13 +1320,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // overlap with device computation.
         ggml_backend_sched_reset(sched.get());
     }
-
-    if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-        synchronize();
-        const size_t n_embd = this->model.hparams.n_embd;
-        double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);
-        LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum);
-    }
     return 0;
 }
 
@@ -3124,7 +3090,7 @@ std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context
     if (cparams.warmup) {
         mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
     } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__);
         mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
             *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
         );
@@ -3160,10 +3126,6 @@ bool llama_context::prepare_mtp_graph_inputs(
     }
 
     if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-        const size_t n_embd = this->model.hparams.n_embd;
-        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-
         const char * op_type;
         if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
             op_type = "MTP_UPDATE";
@@ -3171,8 +3133,6 @@ bool llama_context::prepare_mtp_graph_inputs(
             op_type = "DRAFT_GEN";
         }
 
-        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
         ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
     } else {
         LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
@@ -766,7 +766,6 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    LLAMA_LOG_WARN("%s: Entering find_slot for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
     if (debug > 0) {
         const auto & cells = v_cells[seq_to_stream[1]];
 
@@ -972,9 +971,6 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
             }
         }
     }
-        LLAMA_LOG_WARN("%s: find_slot SUCCEEDED for ubatch of %d tokens. Idxs:%s\n", __func__, ubatch.n_tokens, idxs_str.c_str());
-    } else {
-        LLAMA_LOG_ERROR("%s: find_slot FAILED to allocate cells for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
     }
 
     return res;
@@ -13788,35 +13788,11 @@ struct llm_build_glm4 : public llm_graph_context {
 
 struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-        // LLAMA_LOG_WARN(
-        // "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n",
-        // params.use_mtp_head ? "MTP" : "MAIN",
-        // params.update_mtp_kv ? "true" : "false",
-        // n_tokens,
-        // n_tokens > 0 ? ubatch.pos[0] : -1
-        // );
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         ggml_tensor * cur;
 
-        // LLAMA_LOG_WARN(
-        // "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n",
-        // params.use_mtp_head ? "true" : "false",
-        // params.update_mtp_kv ? "true" : "false",
-        // n_tokens
-        // );
-        // for (int i = 0; i < n_tokens; ++i) {
-        //     LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]);
-        // }
-        if (n_tokens > 0) {
-            LLAMA_LOG_WARN(
-                " - ubatch tokens: [ID=%d, Pos=%d] ... [ID=%d, Pos=%d]\n",
-                ubatch.token[0], ubatch.pos[0],
-                ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1]
-            );
-        }
-
         if (params.mtp_params.op_type != MTP_OP_NONE) {
             ggml_tensor* hidden_states_from_main_model;
 
@@ -13913,10 +13889,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-                if (ubatch.n_tokens > 0) {
-                    LLAMA_LOG_WARN("[KV_WRITE] path=MAIN, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
-                        il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
-                }
+
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -14066,14 +14039,7 @@ private:
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
-                // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il);
-                // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]);
-                // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]);
-                // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]);
-                if (ubatch.n_tokens > 0) {
-                    LLAMA_LOG_WARN("[KV_WRITE] path=MTP, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
-                        il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
-                }
+
                 cur = build_attn(inp_attn,
                         mtp_layer.wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -1738,7 +1738,7 @@ struct server_queue {
 
         while (true) {
             QUE_DBG("%s", "processing new tasks\n");
-            const int64_t t_turn_start_us = ggml_time_us();
+
             while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (!running) {
@@ -1761,11 +1761,7 @@ struct server_queue {
             QUE_DBG("%s", "update slots\n");
 
             callback_update_slots();
-            const int64_t t_turn_end_us = ggml_time_us();
-            SRV_DBG(
-                "[PERF] Server turn time: %.2f ms\n",
-                (t_turn_end_us - t_turn_start_us) / 1000.0
-            );
+
             QUE_DBG("%s", "waiting for new tasks\n");
             {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -3471,7 +3467,6 @@ struct server_context {
                 batch.seq_id + i,
                 batch.logits + i,
             };
-            LOG_INF("\n[DEBUG-CHUNK] Processing main model chunk. Batch size: %d\n", n_tokens);
 
             const int ret = llama_decode(ctx, batch_view);
 
@@ -3569,10 +3564,8 @@ struct server_context {
                 }
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
                 slot.last_tok_idx = tok_idx;
-                //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
 
                 slot.i_batch = -1;
-                SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i);
                 common_sampler_accept(slot.smpl, id, true);
 
                 slot.n_decoded += 1;
@@ -3644,7 +3637,6 @@ struct server_context {
             llama_tokens draft;
             if (slot.has_mtp) {
-                SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past);
                 llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
                 draft.reserve(1);
                 draft.push_back(draft_id);
@@ -3680,26 +3672,14 @@ struct server_context {
             }
 
             SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-            SLT_INF(slot, "[POS-SYNC] Before validation decode. n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens);
             llama_decode(ctx, slot.batch_spec);
 
-            const size_t n_embd = llama_n_embd(llama_get_model(ctx));
-            const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd;
-            const float* golden_embd_ptr = llama_get_embeddings(ctx);
-            double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats);
-            SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %zu tokens)\n", golden_checksum, slot.batch_spec.n_tokens);
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-            SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size());
 
             if (slot.has_mtp) {
-                llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
-
-                const float* embd_after_draft_ptr = llama_get_embeddings(ctx);
-                double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats);
-                SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);
 
                 if (!ids.empty()) {
                     llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
                 } else {
@@ -3707,14 +3687,9 @@ struct server_context {
                 }
 
                 mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);
-
-                const float* embd_after_update_ptr = llama_get_embeddings(ctx);
-                double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
-                SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);
             }
 
             slot.n_past += ids.size();
-            SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past);
             slot.n_decoded += ids.size();
 
             // update how many tokens out of those tested were accepted
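For orientation, the server-side speculative path that remains after this cleanup reads roughly as follows. This is a hedged paraphrase of the hunks above, not verbatim code: the enclosing server function is not shown in the diff, and the step that builds slot.batch_spec is summarized in a comment.

    // 1) draft one token with the MTP head
    llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
    // 2) build slot.batch_spec from the sampled token plus the draft, then validate it in one decode
    llama_decode(ctx, slot.batch_spec);
    // 3) keep the drafted tokens the sampler agrees with
    const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
    if (slot.has_mtp) {
        if (!ids.empty()) {
            // the hidden state of the last accepted token feeds the next MTP draft
            llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
        }
        // write the accepted tokens into the MTP layer's KV cache
        // (their batch now requests logits, per the flag fix in this commit)
        mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);
    }
    slot.n_past    += ids.size();
    slot.n_decoded += ids.size();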