failed attempt to implement MTP; outputs tokens but KV cache management is unreasonable
This commit is contained in:
parent cf0f7c0448
commit 6e9bafc7a7
@@ -348,6 +348,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co

    llama_sampler_apply(chain, &cur_p);

    /*for (int k = 0; k < (int)cur_p.size; ++k) {
        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f)\n",
            k, 0, cur_p.data[k].id, cur_p.data[k].p);
    }*/

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;
@@ -6,6 +6,7 @@
#include "common.h"
#include "sampling.h"
#include "../src/llama-graph.h"
#include "../src/llama-context.h"

#include <cstring>
#include <algorithm>
@@ -362,126 +363,40 @@ llama_tokens common_speculative_gen_draft(
}

llama_tokens mtp_speculative_gen_draft(
    struct common_sampler * smpl,
    struct llama_context * ctx,
    llama_token id_last,
    int32_t n_past,
    int32_t last_tok_idx) {
llama_token mtp_speculative_gen_draft(
    struct common_sampler* smpl,
    struct llama_context* ctx,
    llama_token id_last,
    int32_t n_past,
    int32_t last_tok_idx) {

    llama_tokens result;

    LOG_INF("step: '%d'\n", 1);

    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
    result.reserve(1);

    // need to determine which architecture we're using so we call the correct MTP model
    const auto * model = llama_get_model(ctx);

    LOG_INF("step: '%d'\n", 2);

    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
    //auto * gf = model.build_graph(gparams);

    LOG_INF("step: '%d'\n", 3);

    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
        ret = GGML_STATUS_ALLOC_FAILED;
        return nullptr;
    }*/

    //llm_graph_result res_mtp(ctx->graph_max_nodes());
    llm_graph_result * res_mtp;
    llama_ubatch ubatch_mtp;
    ubatch_mtp.n_tokens = 1;
    ubatch_mtp.pos = &n_past; // Critical for positional encoding

    // We also need a minimal ubatch to provide positional context (RoPE)
    // ubatch_mtp.tokens = &last_token_id;
    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
    // ubatch_mtp.logits = nullptr;
    // ubatch_mtp.all_pos_0 = -1;
    // ubatch_mtp.all_pos_1 = -1;
    // ubatch_mtp.all_seq_id = -1;

    // Manually construct the graph parameters
    //const llm_graph_params params_mtp = {
    //    /*.arch        =*/ model->arch,
    //    /*.hparams     =*/ model->hparams,
    //    /*.cparams     =*/ ctx->cparams,
    //    /*.ubatch      =*/ ubatch_mtp,
    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
    //    /*.sched       =*/ ctx->sched.get(),
    //    /*.backend_cpu =*/ ctx->backend_cpu,
    //    /*.cvec        =*/ &ctx->cvec,
    //    /*.loras       =*/ &ctx->loras,
    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
    //    /*.cross       =*/ &ctx->cross,
    //    /*.n_outputs   =*/ 1,
    //    /*.cb          =*/ ctx->graph_get_cb(),
    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
    //};
    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);

    LOG_INF("step: '%d'\n", 4);

    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
    auto * last_embd = llama_get_embeddings_tensor(ctx);

    LOG_INF("step: '%d'\n", 5);

    GGML_ASSERT(model != nullptr);
    GGML_ASSERT(last_embd != nullptr);
    llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);

    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
    common_sampler_sample(smpl, ctx, last_tok_idx, true);

    if (!gf) {
        LOG_INF("%s: failed to initialize graph\n", __func__);
        //ret = GGML_STATUS_FAILED;
        return result;
    }
    const auto* cur_p = common_sampler_get_candidates(smpl);
    /*LOG_INF("cur_p->size: %d\n", cur_p->size);

    LOG_INF("step: '%d'\n", 6);
    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
        LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
            k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
    }*/

    const auto status = llama_graph_compute(ctx, gf, false);
    // add drafted token for each sequence
    const llama_token id = cur_p->data[0].id;

    LOG_INF("step: '%d'\n", 7);
    // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
    // smpl will accept the token if it doesn't get rejected by main model later
    // common_sampler_accept(smpl, id, true);

    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
    float * ctx_logit_pointer = llama_get_logits(ctx);

    LOG_INF("step: '%d'\n", 8);

    if (logits_mtp) {
        llama_set_logits(ctx, logits_mtp);
    }

    LOG_INF("step: '%d'\n", 9);

    {
        common_sampler_sample(smpl, ctx, last_tok_idx, true);

        LOG_INF("step: '%d'\n", 10);

        const auto * cur_p = common_sampler_get_candidates(smpl);

        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
        }

        // add drafted token for each sequence
        const llama_token id = cur_p->data[0].id;

        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
        // smpl will accept the token if it doesn't get rejected by main model later
        // common_sampler_accept(smpl, id, true);

        result.push_back(id);
    }

    return result;
    //llama_tokens result;
    //result.reserve(1);
    //result.push_back(id);
    //return result;
    return id;
}
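Note on the new return type: mtp_speculative_gen_draft now returns a single llama_token instead of a llama_tokens vector, so callers wrap the result themselves. A minimal calling sketch, using only names that appear in this diff:

    // sketch: wrap the single MTP draft token into the container that the
    // existing speculative-decoding path expects
    llama_tokens draft;
    const llama_token draft_id = mtp_speculative_gen_draft(smpl, ctx, id_last, n_past, last_tok_idx);
    draft.reserve(1);
    draft.push_back(draft_id);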
@@ -29,7 +29,7 @@ void common_speculative_add_replacement_tgt_dft(

// sample up to n_draft tokens and add them to the batch using the draft model
llama_tokens mtp_speculative_gen_draft(
llama_token mtp_speculative_gen_draft(
    struct common_sampler* smpl,
    struct llama_context* ctx,
    llama_token id_last,
@@ -977,8 +977,6 @@ extern "C" {
    // returns NULL for invalid ids.
    LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

    LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override);

    // Get all output token embeddings.
    // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
    // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@@ -1465,6 +1463,9 @@ extern "C" {
    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);

    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
        ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
@@ -523,12 +523,16 @@ float * llama_context::get_logits() {
    return logits;
}

void llama_context::set_logits(struct ggml_tensor * logit_override) {
    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
void llama_context::set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i) {
    output_reorder();

    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched_override, logit_override);
    GGML_ASSERT(backend_res != nullptr);
    GGML_ASSERT(logits != nullptr);

    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
    int64_t j = output_ids[i];

    ggml_backend_tensor_get_async(backend_res, logit_override, logits + j*model.vocab.n_tokens(), 0, model.vocab.n_tokens() * sizeof(float));
}

float * llama_context::get_logits_ith(int32_t i) {
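set_logits_ith copies a single row of MTP logits into the context's packed logits buffer. A short sketch of the indexing it relies on (restating the code above with the llama_context member names; nothing new is introduced):

    // output_ids[i] maps batch position i to its row j in the packed logits
    // buffer, which stores n_outputs rows of n_vocab floats each
    const int64_t n_vocab = model.vocab.n_tokens();
    const int64_t j       = output_ids[i];        // packed output row for batch index i
    float * dst           = logits + j*n_vocab;   // start of that row
    // copy exactly one row: the MTP head produced logits for a single token
    ggml_backend_tensor_get_async(backend_res, logit_override, dst, 0, n_vocab * sizeof(float));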
@@ -1445,21 +1449,23 @@ llm_graph_params llama_context::graph_params(

llm_graph_params llama_context::mtp_graph_params(
        llm_graph_result* res,
        const llama_ubatch& ubatch) const {
        const llama_ubatch& ubatch) {
    size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
    ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
    return {
        /*.arch        =*/ model.arch,
        /*.hparams     =*/ model.hparams,
        /*.cparams     =*/ cparams,
        /*.ubatch      =*/ ubatch,
        /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
        /*.sched       =*/ sched.get(),
        /*.sched       =*/ temp_sched,
        /*.backend_cpu =*/ backend_cpu,
        /*.cvec        =*/ &cvec,
        /*.loras       =*/ &loras,
        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
        /*.cross       =*/ &cross,
        /*.n_outputs   =*/ 1,
        /*.cb          =*/ graph_get_cb(),
        /*.cb          =*/ graph_get_cb(temp_sched),
        /*.res         =*/ res,
    };
}
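The temporary scheduler is sized from a rough per-layer tensor count rather than the usual graph_max_nodes(). A worked example of the formula with assumed, GLM-4.5-Air-like numbers (illustrative only; the real values come from the loaded model):

    // assumed values, purely to show how the node budget scales
    const uint32_t n_layer   = 47;      // hypothetical layer count
    const uint32_t nextn     = 1;       // hparams.nextn_predict_layers
    const uint32_t n_tensors = 11000;   // hypothetical model.n_tensors()
    const size_t   n_nodes   = std::max<uint32_t>(1024u, 8u * 8u * (((nextn + 1) * n_tensors) / n_layer));
    // -> max(1024, 64 * 468) = 29952 graph nodes reserved for the MTP graph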
@@ -1491,8 +1497,10 @@ ggml_status llama_context::graph_compute(
    return status;
}

llm_graph_cb llama_context::graph_get_cb() const {
    return [&](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
llm_graph_cb llama_context::graph_get_cb(ggml_backend_sched * sched_override) const {
    ggml_backend_sched * cb_sched = sched_override ? sched_override : sched.get();

    return [=](const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il) {
        if (il >= 0) {
            ggml_format_name(cur, "%s-%d", name, il);
        } else {
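Along with the new cb_sched local, the lambda's capture mode changes from [&] to [=]: the callback is stored in llm_graph_params and runs after graph_get_cb() has returned, so capturing the local scheduler pointer by reference would dangle. A minimal illustration of the difference (generic C++, not llama.cpp code):

    // by-reference capture of a local/parameter dangles once the factory returns
    auto make_cb_by_ref = [](ggml_backend_sched * sched_local) {
        return [&]() { return sched_local; };   // dangling reference to a dead parameter
    };
    // by-value capture copies the pointer itself, which stays valid
    auto make_cb_by_val = [](ggml_backend_sched * sched_local) {
        return [=]() { return sched_local; };   // safe
    };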
@@ -1502,7 +1510,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
        if (!cparams.offload_kqv) {
            if (strcmp(name, "kqv_merged_cont") == 0) {
                // all nodes between the KV store and the attention output are run on the CPU
                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
                ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend_cpu);
            }
        }
@@ -1515,7 +1523,7 @@ llm_graph_cb llama_context::graph_get_cb() const {
            for (const auto & backend : backends) {
                if (ggml_backend_get_device(backend.get()) == dev_layer) {
                    if (ggml_backend_supports_op(backend.get(), cur)) {
                        ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend.get());
                        ggml_backend_sched_set_tensor_backend(cb_sched, cur, backend.get());
                    }
                }
            }
@@ -1524,6 +1532,10 @@ llm_graph_cb llama_context::graph_get_cb() const {
    };
}

ggml_backend_sched_t llama_context::create_temp_scheduler(size_t n_nodes) {
    return ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), n_nodes, false, cparams.op_offload);
}

//
// state save/load
//
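create_temp_scheduler allocates a fresh ggml_backend_sched per call, and nothing in this diff releases it. A sketch of the lifetime pairing one would normally expect around the MTP draft step (ggml_backend_sched_free is the existing ggml API; whether and where it belongs is an assumption, not something this commit does):

    ggml_backend_sched_t temp_sched = ggml_backend_sched_new(
        backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(),
        n_nodes, /*parallel =*/ false, cparams.op_offload);

    // ... build, allocate and compute the MTP graph against temp_sched ...

    ggml_backend_sched_free(temp_sched);   // without this, every draft step leaks a scheduler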
@@ -2450,10 +2462,6 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
    return ctx->get_logits_ith(i);
}

void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) {
    ctx->set_logits(logit_override);
}

float * llama_get_embeddings(llama_context * ctx) {
    ctx->synchronize();
@@ -2985,3 +2993,37 @@ llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* re
ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
    return ctx->graph_compute(gf, batched);
}

void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
        ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

    const auto * model = llama_get_model(ctx);

    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());

    llama_ubatch ubatch_mtp;
    ubatch_mtp.n_tokens = 1;
    ubatch_mtp.pos = &n_past;

    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));

    auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);

    ggml_backend_sched_t sched = params_mtp->sched;

    ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
    ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

    ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");

    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
    ggml_backend_sched_graph_compute(sched, gf); // execute the graph

    struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
    LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);

    if (logits_mtp) {
        ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
    }
}
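The helper above follows the standard ggml scheduler lifecycle (reset, allocate, fill named inputs, compute, read back) but skips error handling. A condensed sketch of the same pattern with the allocation check that ggml_backend_sched_alloc_graph supports; the lookup of "mtp_token_id_input" by name mirrors the code above:

    ggml_backend_sched_reset(sched);                      // drop allocations of the previous graph
    if (!ggml_backend_sched_alloc_graph(sched, gf)) {     // can fail if the graph exceeds the node budget
        LLAMA_LOG_ERROR("%s: failed to allocate MTP graph\n", __func__);
        return;
    }
    ggml_tensor * inp = ggml_graph_get_tensor(gf, "mtp_token_id_input");
    ggml_backend_tensor_set(inp, &last_token_id, 0, sizeof(last_token_id));
    ggml_backend_sched_graph_compute(sched, gf);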
@@ -200,9 +200,11 @@ public:
    // reserve a graph with a dummy ubatch of the specified size
    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);

    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);

    void set_logits(struct ggml_tensor* logit_override);
    void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);

    ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);

private:
    llm_graph_params graph_params(
@@ -211,7 +213,7 @@ private:
            const llama_memory_context_i * mctx,
            llm_graph_type gtype) const;

    llm_graph_cb graph_get_cb() const;
    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;

    // TODO: read/write lora adapters and cvec
    size_t state_write_data(llama_io_write_i & io);
@@ -13950,7 +13950,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
        // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
    ) : llm_graph_context(params) {

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13958,22 +13957,43 @@
        const int il = hparams.n_layer - 1;
        const auto & mtp_layer = model.layers[il];

        ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_pos, n_past);
        llm_graph_input_attn_no_cache * inp_attn = nullptr;
        // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        // ggml_set_i32(inp_pos, n_past);
        ggml_tensor * inp_pos = build_inp_pos();

        llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;

        ggml_tensor * cur;

        // get MTP embedding for last (conventionally sampled) token
        // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        // LLAMA_LOG_INFO("step: '%d'\n", 5641);
        // ggml_set_i32(inp_token_id, last_token_id);
        //ggml_set_no_alloc(ctx0, false);
        //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id);

        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_i32(inp_token_id, last_token_id);
        ggml_set_name(inp_token_id, "mtp_token_id_input");
        ggml_set_input(inp_token_id);

        //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
        //ggml_set_no_alloc(ctx0, true);

        ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);

        ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
        ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);

        // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
        ggml_tensor * hidden_state_norm = build_norm(hidden_state_inp, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
        ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
        //token_emb_norm = ggml_cont(ctx0, token_emb_norm);
        //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm);

        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat

        cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
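Schematically, this is the GLM/DeepSeek-style MTP head input: the draft-input token's MTP embedding and the previous token's final hidden state are each RMS-normalized, concatenated along the embedding dimension, and projected back to model width. A sketch of the math the graph above implements, using the weight names from the diff:

    x_t = W_{eh\_proj} \cdot \big[\, \mathrm{RMSNorm}_{enorm}(E_{mtp}[\mathrm{last\_token\_id}]) \;\Vert\; \mathrm{RMSNorm}_{hnorm}(h_{\mathrm{prev}}) \,\big]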
@@ -14071,7 +14091,6 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
            cur = ggml_add(ctx0, routed_out, shared_out);
            cb(cur, "ffn_out", il);
        }

        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
@@ -18680,14 +18699,12 @@ ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
    switch (arch) {
        case LLM_ARCH_GLM4_MOE:
            {
                printf("step: '%d'\n", 56);
                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
            } break;
        default:
            GGML_ABORT("fatal error");
    }

    printf("step: '%d'\n", 57);
    return llm->res->get_gf();
}
@@ -19009,8 +19026,8 @@ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_te
ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
    printf("step: '%d'\n", 55);

    return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
}
@@ -2132,6 +2132,8 @@ struct server_context {
                // assume one speculative token (true of all well-known MTP models so far)
                slot.batch_spec = llama_batch_init(2, 0, 1);
                SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);

                params_base.speculative.n_min = 0;
                params_base.speculative.n_max = 1;
            }
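batch_spec is sized for two tokens because the MTP path verifies at most one draft per step: the token just accepted from the main model plus the single drafted token. Roughly how the existing speculative path in the server fills and decodes this batch (a sketch of the surrounding server code, not part of this diff):

    common_batch_clear(slot.batch_spec);
    common_batch_add(slot.batch_spec, id, slot.n_past, { slot.id }, true);                        // token already accepted
    for (size_t i = 0; i < draft.size(); ++i) {
        common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);      // the one MTP draft
    }
    llama_decode(ctx, slot.batch_spec);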
@@ -3587,9 +3589,7 @@ struct server_context {
            }

            // determine the max draft that fits the current slot state
            SLT_DBG(slot, "starting mtp draft: %d\n", 2);
            int n_draft_max = slot.params.speculative.n_max;
            SLT_DBG(slot, "starting mtp draft: %d\n", 3);

            // note: n_past is not yet increased for the `id` token sampled above
            // also, need to leave space for 1 extra token to allow context shifts
@@ -3607,14 +3607,13 @@ struct server_context {
                continue;
            }

            SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp);

            llama_token id = slot.sampled;

            llama_tokens draft;
            if (slot.has_mtp) {
                SLT_DBG(slot, "starting mtp draft: %d\n", 1);
                llama_tokens draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
                llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
                draft.reserve(1);
                draft.push_back(draft_id);
            }
            else {
                struct common_speculative_params params_spec;
@@ -3624,7 +3623,16 @@ struct server_context {
                const llama_tokens& cached_text_tokens = slot.cache_tokens.get_text_tokens();

                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
                draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
            }

            //llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
            //llama_tokens draft;
            //draft.reserve(1);
            //draft.push_back(draft_id);

            for (const auto& str : draft) {
                SLT_DBG(slot, "%s\n", str);
            }

            // ignore small drafts
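The debug loop above formats llama_token values (int32_t) with "%s", which is a format-string mismatch. A type-correct sketch of the same loop (common_token_to_piece is the existing helper used elsewhere in this file):

    for (const llama_token tok : draft) {
        SLT_DBG(slot, "draft token %d: '%s'\n", tok, common_token_to_piece(ctx, tok).c_str());
    }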
@@ -3636,6 +3644,7 @@ struct server_context {
            // keep track of total number of drafted tokens tested
            slot.n_draft_total += draft.size();
            SLT_DBG(slot, "draft size = %d\n", draft.size());

            // construct the speculation batch
            common_batch_clear(slot.batch_spec);
@@ -3652,6 +3661,9 @@ struct server_context {
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);

            // if slot has mtp
            // call

            slot.n_past += ids.size();
            slot.n_decoded += ids.size();
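The "// if slot has mtp / call" placeholder marks where an MTP-specific follow-up after acceptance would go. For reference, the surrounding server code also updates its acceptance statistics at this point; a sketch assuming the existing n_draft_accepted counter:

    // ids contains the token sampled by the main model plus every accepted draft token
    slot.n_draft_accepted += ids.size() - 1;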