add llama_context graph_type
commit 64f05859db
parent a2860dc85e
@@ -194,6 +194,13 @@ extern "C" {
         LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported
     };
 
+    enum llama_graph_type {
+        LLAMA_GRAPH_TYPE_DEFAULT,
+        LLAMA_GRAPH_TYPE_ENCODER,
+        LLAMA_GRAPH_TYPE_DECODER,
+        LLAMA_GRAPH_TYPE_DECODER_MTP,
+    };
+
     // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
     typedef struct llama_token_data {
         llama_token id; // token id
@@ -370,13 +377,14 @@ extern "C" {
         bool kv_unified;   // use a unified buffer across the input sequences when computing the attention
                            // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                            // ref: https://github.com/ggml-org/llama.cpp/pull/14363
-        bool is_mtp;       // create context for Multi-Token Prediction (MTP)
 
         // [EXPERIMENTAL]
         // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
         // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
+
+        llama_graph_type graph_type; // type of the computation graph to be used
     };
 
     // model quantization parameters
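Note: a minimal usage sketch of the new field, assuming the existing public entry points (llama_model_default_params, llama_model_load_from_file, llama_context_default_params, llama_init_from_model); the model path and context size are placeholders, and the MTP flow is only illustrative:

#include "llama.h"

int main(void) {
    // load a model that exposes MTP (nextn) layers; the path is a placeholder
    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_model_load_from_file("model.gguf", mparams);

    struct llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 4096;

    // main context: regular decoder graph (the default)
    cparams.graph_type = LLAMA_GRAPH_TYPE_DEFAULT;
    struct llama_context * ctx = llama_init_from_model(model, cparams);

    // second context dedicated to the Multi-Token Prediction graph
    cparams.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP;
    struct llama_context * ctx_mtp = llama_init_from_model(model, cparams);

    // ... run decoding, then clean up
    llama_free(ctx_mtp);
    llama_free(ctx);
    llama_model_free(model);
    return 0;
}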
@@ -27,8 +27,6 @@ llama_context::llama_context(
     // may need to be backend-dependent
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
 
-    is_mtp = params.is_mtp;
-
     t_start_us = model.t_start_us;
     t_load_us  = model.t_load_us;
 
@@ -189,6 +187,23 @@ llama_context::llama_context(
         }
     }
 
+    switch (params.graph_type) {
+        case LLAMA_GRAPH_TYPE_DEFAULT:
+            gtype = LLM_GRAPH_TYPE_DEFAULT;
+            break;
+        case LLAMA_GRAPH_TYPE_ENCODER:
+            gtype = LLM_GRAPH_TYPE_ENCODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER:
+            gtype = LLM_GRAPH_TYPE_DECODER;
+            break;
+        case LLAMA_GRAPH_TYPE_DECODER_MTP:
+            gtype = LLM_GRAPH_TYPE_DECODER_MTP;
+            break;
+        default:
+            throw std::runtime_error("invalid graph type");
+    }
+
     LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max);
     LLAMA_LOG_INFO("%s: n_ctx     = %u\n", __func__, cparams.n_ctx);
     LLAMA_LOG_INFO("%s: n_ctx_seq = %u\n", __func__, cparams.n_ctx_seq);
@@ -814,7 +829,7 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 }
 
 int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) {
-    if (!ctx_mtp.is_mtp) {
+    if (ctx_mtp.gtype != LLM_GRAPH_TYPE_DECODER_MTP) {
         LLAMA_LOG_ERROR("%s: target context is not MTP\n", __func__);
         return -1;
     }
@@ -1488,7 +1503,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
-    if (is_mtp) {
+    if (gtype == LLM_GRAPH_TYPE_DECODER_MTP) {
         if (model.hparams.nextn_predict_layers == 0) {
             LLAMA_LOG_ERROR("%s: MTP decode called but model does not support MTP\n", __func__);
             return -1;
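Note: with this guard, a decode on an MTP-typed context fails up front when the loaded model has no MTP (nextn) layers. A hedged sketch of how a caller would observe this through the public llama_decode wrapper; decode_mtp is a hypothetical helper, not part of the API:

#include <stdio.h>
#include "llama.h"

// a decode on a context created with LLAMA_GRAPH_TYPE_DECODER_MTP is rejected
// (negative return value) when the model lacks MTP (nextn) layers
static int decode_mtp(struct llama_context * ctx_mtp, struct llama_batch batch) {
    const int ret = llama_decode(ctx_mtp, batch);
    if (ret < 0) {
        fprintf(stderr, "MTP decode rejected (model may not support MTP)\n");
    }
    return ret;
}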
@@ -1665,7 +1680,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     ggml_status status;
-    llm_graph_type gtype = is_mtp ? LLM_GRAPH_TYPE_DECODER_MTP : LLM_GRAPH_TYPE_DECODER;
     const auto * res = process_ubatch(ubatch, gtype, mctx.get(), status);
 
     if (!res) {
@@ -3029,9 +3043,9 @@ llama_context_params llama_context_default_params() {
         /*.op_offload =*/ true,
         /*.swa_full   =*/ true,
         /*.kv_unified =*/ false,
-        /*.is_mtp     =*/ false,
         /*.sampler    =*/ nullptr,
         /*.n_sampler  =*/ 0,
+        /*.graph_type =*/ LLAMA_GRAPH_TYPE_DEFAULT,
     };
 
     return result;
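Note: since is_mtp is dropped from llama_context_params, callers that previously requested an MTP context need a one-line change; default-initialized params are unaffected because graph_type now starts as LLAMA_GRAPH_TYPE_DEFAULT. A sketch of the migration:

struct llama_context_params cparams = llama_context_default_params();

// before this change (field removed):
//     cparams.is_mtp = true;
// after:
cparams.graph_type = LLAMA_GRAPH_TYPE_DECODER_MTP;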
@@ -351,7 +351,7 @@ private:
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
 
-    bool is_mtp = false;
+    llm_graph_type gtype;
 
     bool has_evaluated_once = false;
 