broad thrust of the mtp implementation

parent 03231da69e
commit cf0f7c0448
@@ -5,6 +5,7 @@
 #include "log.h"
 #include "common.h"
 #include "sampling.h"
+#include "../src/llama-graph.h"
 
 #include <cstring>
 #include <algorithm>
@@ -359,3 +360,128 @@ llama_tokens common_speculative_gen_draft(
     }
     return result;
 }
+
+
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler * smpl,
+        struct llama_context * ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx) {
+
+    llama_tokens result;
+
+    LOG_INF("step: '%d'\n", 1);
+
+    // sample one token from the draft model -- this does NOT generalize to >1 MTP head
+    result.reserve(1);
+
+    // need to determine which architecture we're using so we call the correct MTP model
+    const auto * model = llama_get_model(ctx);
+
+    LOG_INF("step: '%d'\n", 2);
+
+    //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
+    //auto * gf = model.build_graph(gparams);
+
+    LOG_INF("step: '%d'\n", 3);
+
+    /*if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+        ret = GGML_STATUS_ALLOC_FAILED;
+        return nullptr;
+    }*/
+
+    //llm_graph_result res_mtp(ctx->graph_max_nodes());
+    llm_graph_result * res_mtp;
+    llama_ubatch ubatch_mtp;
+    ubatch_mtp.n_tokens = 1;
+    ubatch_mtp.pos = &n_past; // Critical for positional encoding
+
+    // We also need a minimal ubatch to provide positional context (RoPE)
+    // ubatch_mtp.tokens = &last_token_id;
+    // ubatch_mtp.seq_id = llama_get_main_seq_id(ctx); // Assuming a helper
+    // ubatch_mtp.logits = nullptr;
+    // ubatch_mtp.all_pos_0 = -1;
+    // ubatch_mtp.all_pos_1 = -1;
+    // ubatch_mtp.all_seq_id = -1;
+
+    // Manually construct the graph parameters
+    //const llm_graph_params params_mtp = {
+    //    /*.arch        =*/ model->arch,
+    //    /*.hparams     =*/ model->hparams,
+    //    /*.cparams     =*/ ctx->cparams,
+    //    /*.ubatch      =*/ ubatch_mtp,
+    //    /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
+    //    /*.sched       =*/ ctx->sched.get(),
+    //    /*.backend_cpu =*/ ctx->backend_cpu,
+    //    /*.cvec        =*/ &ctx->cvec,
+    //    /*.loras       =*/ &ctx->loras,
+    //    /*.mctx        =*/ llama_get_memory(ctx), // Use the KV cache's memory context
+    //    /*.cross       =*/ &ctx->cross,
+    //    /*.n_outputs   =*/ 1,
+    //    /*.cb          =*/ ctx->graph_get_cb(),
+    //    /*.res         =*/ &res_mtp, // Point to our temporary result object
+    //};
+    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);
+
+    LOG_INF("step: '%d'\n", 4);
+
+    // ggml_cgraph* build_mtp_graph(const llm_graph_params & params,
+    //     ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
+    auto * last_embd = llama_get_embeddings_tensor(ctx);
+
+    LOG_INF("step: '%d'\n", 5);
+
+    GGML_ASSERT(model != nullptr);
+    GGML_ASSERT(last_embd != nullptr);
+
+    auto * gf = llama_build_mtp_graph(model, params_mtp, last_embd, id_last, n_past);
+
+    if (!gf) {
+        LOG_INF("%s: failed to initialize graph\n", __func__);
+        //ret = GGML_STATUS_FAILED;
+        return result;
+    }
+
+    LOG_INF("step: '%d'\n", 6);
+
+    const auto status = llama_graph_compute(ctx, gf, false);
+
+    LOG_INF("step: '%d'\n", 7);
+
+    struct ggml_tensor * logits_mtp = llama_graph_result_get_logits(res_mtp);
+    float * ctx_logit_pointer = llama_get_logits(ctx);
+
+    LOG_INF("step: '%d'\n", 8);
+
+    if (logits_mtp) {
+        llama_set_logits(ctx, logits_mtp);
+    }
+
+    LOG_INF("step: '%d'\n", 9);
+
+    {
+        common_sampler_sample(smpl, ctx, last_tok_idx, true);
+
+        LOG_INF("step: '%d'\n", 10);
+
+        const auto * cur_p = common_sampler_get_candidates(smpl);
+
+        for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
+            LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+        }
+
+        // add drafted token for each sequence
+        const llama_token id = cur_p->data[0].id;
+
+        // skip accepting draft token -- since we're only drafting one token this can't affect future outputs
+        // smpl will accept the token if it doesn't get rejected by main model later
+        // common_sampler_accept(smpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
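Note on the hunk above: `res_mtp` is declared but never pointed at a live `llm_graph_result` before it is handed to `llama_mtp_graph_params` and later read through `llama_graph_result_get_logits`, so as written the draft path dereferences an indeterminate pointer. The commented-out `llm_graph_result res_mtp(ctx->graph_max_nodes());` line hints at the intended shape. A minimal sketch of one way to close the gap, assuming a small public accessor for the node budget were exposed alongside the new API (the `llama_graph_max_nodes` helper below is hypothetical and not part of this commit):

    // Hypothetical sketch -- llm_graph_result's constructor and the graph node budget
    // live in src/, so this assumes an accessor is exported next to llama_mtp_graph_params.
    llm_graph_result res_mtp_storage(llama_graph_max_nodes(ctx)); // hypothetical accessor
    llm_graph_result * res_mtp = &res_mtp_storage;

    llm_graph_params params_mtp = llama_mtp_graph_params(ctx, res_mtp, ubatch_mtp);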
@@ -27,6 +27,15 @@ void common_speculative_add_replacement_tgt_dft(
         struct common_speculative * spec,
         const char *source, const char *dest);
 
+
+// sample a single draft token using the model's MTP head
+llama_tokens mtp_speculative_gen_draft(
+        struct common_sampler* smpl,
+        struct llama_context* ctx,
+        llama_token id_last,
+        int32_t n_past,
+        int32_t last_tok_idx);
+
 // sample up to n_draft tokens and add them to the batch using the draft model
 llama_tokens common_speculative_gen_draft(
         struct common_speculative * spec,
@@ -544,12 +544,17 @@ extern "C" {
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
 
+    LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
+            struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
             const llama_model_quantize_params * params);
 
+
+
     //
     // Adapters
     //
@@ -972,6 +977,8 @@ extern "C" {
     // returns NULL for invalid ids.
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
 
+    LLAMA_API void llama_set_logits(struct llama_context* ctx, struct ggml_tensor* logit_override);
+
     // Get all output token embeddings.
     // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
     // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
@@ -994,6 +1001,8 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
+
     //
     // Vocab
     //
@@ -1452,6 +1461,14 @@ extern "C" {
             ggml_opt_epoch_callback callback_train,
             ggml_opt_epoch_callback callback_eval);
 
+    LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
+
+    LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
+
+    LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
+
+
+
 #ifdef __cplusplus
 }
 #endif
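Read together, these header additions expose the pieces that `mtp_speculative_gen_draft` strings together on the common/ side. A condensed sketch of the intended call sequence (names such as `ctx`, `res`, `id_last` and `n_past` are taken from that function; logging and error handling dropped):

    // Sketch only -- mirrors the flow of mtp_speculative_gen_draft earlier in this commit.
    llama_ubatch ub = {};
    ub.n_tokens = 1;
    ub.pos      = &n_past;                                          // RoPE position of the token being drafted

    llm_graph_params params = llama_mtp_graph_params(ctx, res, ub); // res: caller-owned llm_graph_result*

    ggml_tensor * last_embd = llama_get_embeddings_tensor(ctx);     // hidden state kept from the last decode
    ggml_cgraph * gf = llama_build_mtp_graph(llama_get_model(ctx), params, last_embd, id_last, n_past);

    llama_graph_compute(ctx, gf, /*batched=*/false);
    llama_set_logits(ctx, llama_graph_result_get_logits(res));      // overwrite ctx logits with the MTP head output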
@@ -6,6 +6,7 @@
 #include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
+#include "llama-graph.h"
 
 #include <cinttypes>
 #include <cstring>
@@ -522,6 +523,14 @@ float * llama_context::get_logits() {
     return logits;
 }
 
+void llama_context::set_logits(struct ggml_tensor * logit_override) {
+    ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
+    GGML_ASSERT(backend_res != nullptr);
+    GGML_ASSERT(logits != nullptr);
+
+    ggml_backend_tensor_get_async(backend_res, logit_override, logits, 0, model.vocab.n_tokens() * sizeof(float));
+}
+
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
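As written, `set_logits` copies exactly one vocab-sized row (`model.vocab.n_tokens()` floats) from the override tensor into the start of the context's logits buffer, while the draft path samples with `last_tok_idx` as the output index. If those two ever need to line up on a row other than 0, a row-targeted variant could look roughly like this (hypothetical helper, not part of this commit; assumes the same n_vocab-floats-per-row layout that `get_logits_ith` uses):

    // Hypothetical sketch: copy the single MTP logit row into row i of the output buffer
    // instead of row 0.
    void llama_context::set_logits_ith(struct ggml_tensor * logit_override, int32_t i) {
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), logit_override);
        GGML_ASSERT(backend_res != nullptr);
        GGML_ASSERT(logits != nullptr);

        const size_t n_vocab = model.vocab.n_tokens();
        ggml_backend_tensor_get_async(backend_res, logit_override, logits + (size_t) i * n_vocab, 0, n_vocab * sizeof(float));
    }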
@@ -617,6 +626,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+ggml_tensor * llama_context::get_embeddings_tensor() {
+    return embd_tensor;
+}
+
 void llama_context::attach_threadpool(
         ggml_threadpool_t threadpool,
         ggml_threadpool_t threadpool_batch) {
@@ -1113,6 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     auto * t_logits = res->get_logits();
     auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
+    embd_tensor = res->get_embd();
 
     if (t_embd && res->get_embd_pooled()) {
         t_embd = res->get_embd_pooled();
@@ -1429,6 +1443,27 @@ llm_graph_params llama_context::graph_params(
     };
 }
 
+llm_graph_params llama_context::mtp_graph_params(
+        llm_graph_result* res,
+        const llama_ubatch& ubatch) const {
+    return {
+        /*.arch        =*/ model.arch,
+        /*.hparams     =*/ model.hparams,
+        /*.cparams     =*/ cparams,
+        /*.ubatch      =*/ ubatch,
+        /*.gtype       =*/ LLM_GRAPH_TYPE_DECODER,
+        /*.sched       =*/ sched.get(),
+        /*.backend_cpu =*/ backend_cpu,
+        /*.cvec        =*/ &cvec,
+        /*.loras       =*/ &loras,
+        /*.mctx        =*/ memory->init_batch(*balloc, 1, false).get(),
+        /*.cross       =*/ &cross,
+        /*.n_outputs   =*/ 1,
+        /*.cb          =*/ graph_get_cb(),
+        /*.res         =*/ res,
+    };
+}
+
 ggml_status llama_context::graph_compute(
             ggml_cgraph * gf,
             bool batched) {
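One wiring detail worth flagging in the hunk above: if `init_batch` returns an owning pointer (as it does elsewhere in this file), calling `.get()` on the temporary leaves the stored `mctx` dangling as soon as the initializer finishes. A minimal sketch of one way around that, assuming the owning `llama_memory_context_ptr` type used by `decode()`, and a hypothetical variant of `mtp_graph_params` that accepts the raw view as an extra argument (neither is part of this commit):

    // Sketch: the caller owns the memory context for the whole build/compute,
    // and the graph params only ever hold a raw, non-owning view of it.
    llama_memory_context_ptr mctx_mtp = memory->init_batch(*balloc, 1, false);   // owning
    GGML_ASSERT(mctx_mtp);

    llm_graph_params params_mtp = mtp_graph_params(res, ubatch, mctx_mtp.get()); // hypothetical 3-arg variant
    // ... llama_build_mtp_graph / llama_graph_compute run while mctx_mtp is still in scope ...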
@@ -2233,6 +2268,7 @@ void llama_context::opt_epoch(
     llama_batch_free(batch);
 }
 
+
 //
 // interface implementation
 //
@@ -2274,6 +2310,8 @@ llama_context_params llama_context_default_params() {
     return result;
 }
 
+
+
 llama_context * llama_init_from_model(
         llama_model * model,
         llama_context_params params) {
@@ -2412,6 +2450,11 @@ float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     return ctx->get_logits_ith(i);
 }
 
+void llama_set_logits(llama_context* ctx, struct ggml_tensor* logit_override) {
+    ctx->set_logits(logit_override);
+}
+
+
 float * llama_get_embeddings(llama_context * ctx) {
     ctx->synchronize();
 
@@ -2430,6 +2473,13 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
+    ctx->synchronize();
+
+    return ctx->get_embeddings_tensor();
+}
+
+
 // llama adapter API
 
 int32_t llama_set_adapter_lora(
@@ -2926,3 +2976,12 @@ void llama_opt_epoch(
             callback_train,
             callback_eval);
 }
+
+llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
+    return ctx->mtp_graph_params(res, ubatch);
+}
+
+
+ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
+    return ctx->graph_compute(gf, batched);
+}
@@ -59,6 +59,7 @@ struct llama_context {
     float * get_embeddings();
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
+    ggml_tensor * get_embeddings_tensor();
 
     void attach_threadpool(
             ggml_threadpool_t threadpool,
@@ -199,6 +200,10 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
+    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch) const;
+
+    void set_logits(struct ggml_tensor* logit_override);
+
 private:
     llm_graph_params graph_params(
             llm_graph_result * res,
@@ -240,6 +245,7 @@ private:
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     size_t embd_size = 0; // capacity (of floats) for embeddings
     float * embd = nullptr;
+    ggml_tensor * embd_tensor = nullptr;
 
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -308,3 +314,4 @@ private:
 
     mutable int32_t n_reused = 0; // number of times the previous graph was reused
 };
+
@@ -1911,3 +1911,7 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
 
     return relative_bucket;
 }
+
+ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
+    return res->get_logits();
+}
@@ -818,3 +818,4 @@ struct llm_graph_context {
 
 // TODO: better name
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
+
@@ -18673,19 +18673,21 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     return llm->res->get_gf();
 }
 
-ggml_cgraph* llama_model::build_mtp_graph(const llm_graph_params& params,
+ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
     ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
     std::unique_ptr<llm_graph_context> llm;
 
     switch (arch) {
         case LLM_ARCH_GLM4_MOE:
             {
+                printf("step: '%d'\n", 56);
                 llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
+    printf("step: '%d'\n", 57);
     return llm->res->get_gf();
 }
 
@@ -19004,3 +19006,11 @@ bool llama_model_is_diffusion(const llama_model * model) {
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
     return model->tensors_by_name;
 }
+
+ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
+        ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
+    printf("step: '%d'\n", 55);
+
+    return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
+}
+
@@ -1294,7 +1294,8 @@ struct server_slot {
     mtmd_context * mctx = nullptr;
 
     common_speculative * spec = nullptr;
     bool has_mtp = false;
+    int32_t last_tok_idx = -1;
 
     std::vector<common_adapter_lora_info> lora;
 
@@ -1432,8 +1433,8 @@ struct server_slot {
     }
 
     bool can_speculate() const {
-        // return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
-        return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
+        return (ctx_dft || has_mtp) && params.speculative.n_max > 0 && params.cache_prompt;
+        // return (ctx_dft) && params.speculative.n_max > 0 && params.cache_prompt;
     }
 
     void add_token(const completion_token_output & token) {
@@ -1993,7 +1994,7 @@ struct server_context {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
             return false;
         }
-
+
         vocab = llama_model_get_vocab(model);
 
         n_ctx = llama_n_ctx(ctx);
@@ -3531,6 +3532,7 @@ struct server_context {
                 const int tok_idx = slot.i_batch - i;
 
                 llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+                slot.last_tok_idx = tok_idx;
 
                 slot.i_batch = -1;
 
@@ -3567,6 +3569,8 @@ struct server_context {
                 }
             }
 
+            SRV_DBG("starting speculative decoding: %d\n", 1);
+
             // do speculative decoding
             for (auto & slot : slots) {
                 if (!slot.is_processing() || !slot.can_speculate()) {
@@ -3583,7 +3587,9 @@ struct server_context {
                 }
 
                 // determine the max draft that fits the current slot state
+                SLT_DBG(slot, "starting mtp draft: %d\n", 2);
                 int n_draft_max = slot.params.speculative.n_max;
+                SLT_DBG(slot, "starting mtp draft: %d\n", 3);
 
                 // note: n_past is not yet increased for the `id` token sampled above
                 // also, need to leave space for 1 extra token to allow context shifts
@@ -3601,15 +3607,25 @@
                     continue;
                 }
 
+                SLT_DBG(slot, "slot has mtp: %d\n", slot.has_mtp);
+
                 llama_token id = slot.sampled;
 
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
+                llama_tokens draft;
+                if (slot.has_mtp) {
+                    SLT_DBG(slot, "starting mtp draft: %d\n", 1);
+                    draft = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
+                }
+                else {
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
 
-                const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                    const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
+
+                    draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
+                }
 
                 // ignore small drafts
                 if (slot.params.speculative.n_min > (int) draft.size()) {