llama_mtp_start
commit de17303dc6 (parent b4457e48bb)

@@ -370,6 +370,7 @@ extern "C" {
         bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                          // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                          // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+        bool is_mtp;     // create context for Multi-Token Prediction (MTP)
 
         // [EXPERIMENTAL]
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
@@ -939,14 +940,6 @@ extern "C" {
             struct llama_context * ctx,
             struct llama_batch batch);
-
-    // Process a batch of tokens using MTP (Multi-Token Prediction).
-    // The input token can be either from the last llama_decode() call,
-    // or from the previous llama_decode_mtp() call.
-    // Input token order must match the output token order from the previous call.
-    LLAMA_API int32_t llama_decode_mtp(
-            struct llama_context * ctx,
-            struct llama_batch batch);
 
     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
@@ -1014,6 +1007,10 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Copy the internal MTP state from ctx_llm to ctx_mtp, ready for MTP decoding.
+    // This must be done before calling llama_decode() on ctx_mtp
+    LLAMA_API int32_t llama_mtp_start(struct llama_context * ctx_llm, struct llama_context * ctx_mtp);
+
     //
     // backend sampling API [EXPERIMENTAL]
     // note: use only if the llama_context was created with at least one llama_sampler_seq_config
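
Taken together, these header changes replace the removed llama_decode_mtp() entry point with a two-context flow: the main context produces the MTP state and a second context, created with is_mtp = true, consumes it via plain llama_decode(). A minimal usage sketch, not part of the commit: it assumes the model carries the MTP tensors, that `batch` and `draft_batch` are prepared by the caller, the helper name run_with_mtp is illustrative, and error handling is omitted.

    #include "llama.h"

    // Sketch only: `model`, `batch` and `draft_batch` are assumed to be set up by
    // the caller; real code must check every return value.
    static void run_with_mtp(llama_model * model, llama_batch batch, llama_batch draft_batch) {
        // regular context for the main model
        llama_context_params params_llm = llama_context_default_params();
        llama_context * ctx_llm = llama_init_from_model(model, params_llm);

        // dedicated context for Multi-Token Prediction
        llama_context_params params_mtp = llama_context_default_params();
        params_mtp.is_mtp = true; // new flag introduced by this commit
        llama_context * ctx_mtp = llama_init_from_model(model, params_mtp);

        // 1. decode with the main model as usual
        llama_decode(ctx_llm, batch);

        // 2. hand the internal MTP state over to the MTP context
        llama_mtp_start(ctx_llm, ctx_mtp);

        // 3. run the draft tokens through the MTP context with plain llama_decode()
        llama_decode(ctx_mtp, draft_batch);

        llama_free(ctx_mtp);
        llama_free(ctx_llm);
    }
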
@@ -27,6 +27,8 @@ llama_context::llama_context(
     // may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
+    is_mtp = params.is_mtp;
+
     t_start_us = model.t_start_us;
     t_load_us = model.t_load_us;
 
@@ -814,6 +816,23 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
 
+int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
+    if (!ctx_mtp.is_mtp) {
+        LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
+        return -1;
+    }
+
+    if (cross.n_token == 0 || cross.n_embd == 0) {
+        LLAMA_LOG_ERROR("%s: no state to copy\n", __func__);
+        return -1;
+    }
+
+    // TODO: maybe std::move is better?
+    ctx_mtp.cross = cross;
+
+    return 0;
+}
+
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
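
On the TODO above: if the source context no longer needs its cross state after the hand-off, the buffer could be moved instead of copied. A hedged sketch only, assuming llama_cross owns a cheaply movable buffer and is value-initializable; this is not what the commit does:

    // alternative to the plain copy above (assumption: llama_cross is cheaply movable)
    ctx_mtp.cross = std::move(cross);
    cross = {}; // leave the source state empty so it is not reused by accident
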
@@ -1456,7 +1475,7 @@ static void copy_tensor_async_candidates(
     }
 }
 
-int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
+int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
     if (is_mtp) {
@@ -1464,7 +1483,7 @@ int llama_context::decode(const llama_batch & batch_inp, bool is_mtp) {
         LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
         return -1;
     }
-    if (batch_inp.n_tokens > n_ubatch()) {
+    if ((uint32_t)batch_inp.n_tokens > n_ubatch()) {
         // TODO @ngxson : n_tokens > ubatch will mess up the llama_cross state, may need to fix it later
         LLAMA_LOG_ERROR("%s: MTP decode requires n_ubatch >= n_tokens\n", __func__);
         return -1;
@@ -3043,6 +3062,7 @@ llama_context_params llama_context_default_params() {
         /*.op_offload =*/ true,
         /*.swa_full =*/ true,
         /*.kv_unified =*/ false,
+        /*.is_mtp =*/ false,
         /*.sampler =*/ nullptr,
         /*.n_sampler =*/ 0,
     };
@@ -3233,6 +3253,12 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
 
+int32_t llama_mtp_start(llama_context * ctx_llm, llama_context * ctx_mtp) {
+    ctx_llm->synchronize();
+
+    return ctx_llm->cpy_mtp_state(*ctx_mtp);
+}
+
 bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
     return ctx->set_sampler(seq_id, smpl);
 }
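
llama_mtp_start() synchronizes the source context before the copy and forwards the error codes from cpy_mtp_state(). A small caller-side check, as a sketch (the printed message is illustrative):

    // returns a negative value if ctx_mtp was not created with params.is_mtp = true
    // or if ctx_llm has not produced any MTP state yet
    if (llama_mtp_start(ctx_llm, ctx_mtp) != 0) {
        fprintf(stderr, "llama_mtp_start failed\n");
    }
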
@@ -3553,7 +3579,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
         llama_batch batch) {
-    const int ret = ctx->decode(batch, false);
+    const int ret = ctx->decode(batch);
     if (ret != 0 && ret != 1) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -3561,17 +3587,6 @@ int32_t llama_decode(
     return ret;
 }
-
-int32_t llama_decode_mtp(
-        llama_context * ctx,
-        llama_batch batch) {
-    const int ret = ctx->decode(batch, true);
-    if (ret != 0 && ret != 1) {
-        LLAMA_LOG_ERROR("%s: failed to decode MTP, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
 
 //
 // perf
 //
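
For callers of the removed function, the migration is mechanical; a sketch, with ctx_llm/ctx_mtp and draft_batch standing in for the caller's own objects:

    // before this commit (removed API):
    //     llama_decode_mtp(ctx, draft_batch);

    // after this commit: the MTP pass runs on a separate context created with is_mtp = true
    llama_mtp_start(ctx_llm, ctx_mtp);
    llama_decode(ctx_mtp, draft_batch);
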
@@ -77,6 +77,8 @@ struct llama_context {
     float * get_embeddings();
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
 
+    int32_t cpy_mtp_state(llama_context & ctx_mtp);
+
     llama_token * get_sampled_tokens() const;
     llama_token get_sampled_token_ith(int32_t idx);
@@ -131,7 +133,7 @@ struct llama_context {
             ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
-    int decode(const llama_batch & batch_inp, bool is_mtp);
+    int decode(const llama_batch & batch_inp);
 
     //
     // state save/load
@@ -349,6 +351,8 @@ private:
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
+    bool is_mtp = false;
+
     bool has_evaluated_once = false;
 
     // env: LLAMA_GRAPH_REUSE_DISABLE
@@ -54,7 +54,6 @@ llm_build_glm4_moe<false>::llm_build_glm4_moe(const llama_model & model, const l
 template <>
 llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
     const int64_t n_embd_head = hparams.n_embd_head_v;
-    const bool use_mrope = hparams.use_mrope();
 
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
@@ -73,6 +72,9 @@ llm_build_glm4_moe<true>::llm_build_glm4_moe(const llama_model & model, const ll
             ? mtp_layer.nextn.embed_tokens : model.tok_embd);
     ggml_tensor * inp_state_embd = build_inp_cross_mtp();
 
+    // check number of input tokens
+    GGML_ASSERT(inp_state_embd->ne[1] == inp_token_embd->ne[1]);
+
     inp_token_embd = build_norm(inp_token_embd, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
     inp_state_embd = build_norm(inp_state_embd, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
 
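
The added assert compares the token dimension of the two graph inputs. As a reading aid only, assuming the usual ggml layout for these activations (ne[0] = n_embd, ne[1] = n_tokens):

    // both inputs must cover the same number of tokens
    const int64_t n_tokens_tok   = inp_token_embd->ne[1]; // token embeddings of the current ubatch
    const int64_t n_tokens_state = inp_state_embd->ne[1]; // hidden states copied in via llama_mtp_start()
    GGML_ASSERT(n_tokens_tok == n_tokens_state);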