Merge d10a5a4a5b into 18ddaea2ae

commit 465fa04f2e
@ -3229,6 +3229,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.speculative.cache_type_k = kv_cache_type_from_str(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
{"-mtp", "--multi-token-prediction"},
|
||||
string_format("Activate multi-token-prediction (if supported) (default: %s)", params.mtp ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.mtp = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"-ctvd", "--cache-type-v-draft"}, "TYPE",
|
||||
string_format(
|
||||
|
|
|
|||
|
|
@ -1351,6 +1351,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
|
|||
mparams.check_tensors = params.check_tensors;
|
||||
mparams.use_extra_bufts = !params.no_extra_bufts;
|
||||
mparams.no_host = params.no_host;
|
||||
mparams.mtp = params.mtp;
|
||||
|
||||
if (params.kv_overrides.empty()) {
|
||||
mparams.kv_overrides = NULL;
|
||||
|
|
|
|||
|
|
@ -430,6 +430,7 @@ struct common_params {
|
|||
bool no_op_offload = false; // globally disable offload host tensor operations to device
|
||||
bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)
|
||||
bool no_host = false; // bypass host buffer allowing extra buffers to be used
|
||||
bool mtp = false; // enable MTP if supported by the model
|
||||
|
||||
bool single_turn = false; // single turn chat conversation
|
||||
|
||||
|
|
|
|||
|
|
@ -666,3 +666,29 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||
|
||||
return samplers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Specialized sampling for speculative drafting.
|
||||
*
|
||||
* Prioritizes performance by using a direct ArgMax loop (Greedy).
|
||||
* Penalties and complex sampling logic are bypassed to minimize
|
||||
* drafting latency.
|
||||
*/
|
||||
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
|
||||
const auto & params = gsmpl->params;
|
||||
|
||||
float * logits = llama_get_logits_ith(ctx, idx);
|
||||
const int n_vocab = llama_n_vocab(llama_model_get_vocab(llama_get_model(ctx)));
|
||||
|
||||
int best_id = 0;
|
||||
float max_val = logits[0];
|
||||
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
if (logits[i] > max_val) {
|
||||
max_val = logits[i];
|
||||
best_id = i;
|
||||
}
|
||||
}
|
||||
|
||||
return best_id;
|
||||
}
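A minimal usage sketch, not part of this patch: the helper is meant to be called right after a decode whose batch requested logits, reading output index 0. The wrapper name draft_one_token and the single-token batch are illustrative assumptions.

// Illustrative only: greedily pick one token after decoding `id_last` at position `n_past`.
static llama_token draft_one_token(common_sampler * gsmpl, llama_context * ctx,
                                   llama_token id_last, int32_t n_past, llama_seq_id seq_id) {
    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, { seq_id }, true); // request logits for this token

    llama_token id = LLAMA_TOKEN_NULL;
    if (llama_decode(ctx, batch) == 0) {
        // argmax over the raw logits of the batch's only output
        id = common_sampler_sample_speculative(gsmpl, ctx, 0);
    }

    llama_batch_free(batch);
    return id;
}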
|
||||
|
|
|
|||
|
|
@ -115,3 +115,5 @@ struct common_sampler_deleter {
|
|||
};
|
||||
|
||||
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
||||
|
||||
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -359,3 +359,116 @@ llama_tokens common_speculative_gen_draft(
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
llama_tokens mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
struct common_speculative_params params,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id) {
|
||||
|
||||
int n_draft = params.n_draft;
|
||||
|
||||
llama_tokens drafts;
|
||||
drafts.reserve(n_draft);
|
||||
|
||||
if (!smpl) return drafts;
|
||||
|
||||
llama_batch mtp_batch = llama_batch_init(1, 0, 1);
|
||||
mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
|
||||
|
||||
llama_token current_input_id = id_last;
|
||||
int32_t current_n_past = n_past;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
mtp_batch.n_tokens = 0;
|
||||
common_batch_add(mtp_batch, current_input_id, current_n_past, {seq_id}, true);
|
||||
|
||||
// Perform the MTP draft generation decode. This writes the MTP layer's
|
||||
// KV state for the draft token into the cache.
|
||||
if (llama_decode(ctx, mtp_batch) != 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
llama_token id_next = common_sampler_sample_speculative(smpl, ctx, 0);
|
||||
|
||||
// Drafting stops if token probability drops below `p_min` to save compute.
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
if (cur_p && cur_p->size > 0) {
|
||||
float prob = cur_p->data[0].p;
|
||||
|
||||
if (prob < params.p_min) {
|
||||
drafts.push_back(id_next);
|
||||
current_n_past++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
drafts.push_back(id_next);
|
||||
|
||||
current_input_id = id_next;
|
||||
current_n_past++;
|
||||
}
|
||||
llama_batch_free(mtp_batch);
|
||||
|
||||
    // CRITICAL: Purge the metadata for the draft tokens we just wrote.
    // This makes the physical cells available again for the main model's validation pass,
    // preventing cache state corruption where two cells map to the same logical position.
|
||||
if (!drafts.empty()) {
|
||||
llama_kv_cache_seq_rm(ctx, seq_id, n_past, current_n_past);
|
||||
}
|
||||
|
||||
return drafts;
|
||||
}
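A hedged sketch of how a caller might drive this function. How the main model's hidden state for id_last is captured is server-side logic outside this diff, so last_hidden_state is taken here as a given; the function name and parameter values are illustrative, not defaults imposed by this patch.

// Sketch: produce an MTP draft continuing from `id_last`, assuming `last_hidden_state`
// holds the main model's hidden state for that token (captured elsewhere).
static llama_tokens draft_with_mtp(common_sampler * smpl, llama_context * ctx,
                                   const float * last_hidden_state,
                                   llama_token id_last, int32_t n_past, llama_seq_id seq_id) {
    llama_set_draft_input_hidden_state(ctx, last_hidden_state);

    common_speculative_params sparams;
    sparams.n_draft = 4;     // keep drafts short to bound the extra MTP decodes
    sparams.p_min   = 0.75f; // stop drafting once the head's confidence drops

    return mtp_speculative_gen_draft(smpl, ctx, sparams, id_last, n_past, seq_id);
}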
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup) {
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
|
||||
|
||||
llama_batch mtp_batch = batch;
|
||||
if (is_prompt_warmup) {
|
||||
mtp_batch.mtp_params.op_type = MTP_OP_WARMUP;
|
||||
} else {
|
||||
mtp_batch.mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED;
|
||||
}
|
||||
|
||||
for (int i = 0; i < mtp_batch.n_tokens; ++i) {
|
||||
mtp_batch.logits[i] = true;
|
||||
}
|
||||
llama_decode(ctx, mtp_batch);
|
||||
}
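A sketch of the prompt-warmup path, assuming the caller brackets it with the sinfo-reuse API added in this change so that the MTP pass writes into exactly the cells used by the main model's prompt pass; the function name is illustrative.

// Sketch: right after the main model has decoded the prompt, replay the same batch
// through the MTP layer so its KV cache covers the same positions.
static void warmup_mtp_after_prompt(llama_context * ctx, const llama_batch & prompt_batch) {
    // reuse the slot layout of the last main-model decode instead of allocating new cells
    if (!llama_mtp_prepare_sinfo_for_warmup(ctx)) {
        return;
    }

    mtp_update_kv_cache(ctx, prompt_batch, /*is_prompt_warmup=*/ true);

    // always clear the forced sinfo so the next decode goes through find_slot again
    llama_mtp_cancel_sinfo_update(ctx);
}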
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Prepare a resized copy of the validation sinfo to match the number of accepted tokens.
|
||||
// This sets up the context for a "forced sinfo" decode.
|
||||
if (!llama_mtp_prepare_sinfo_for_update(ctx, ids.size())) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a new batch containing only the accepted tokens.
|
||||
llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
|
||||
}
|
||||
|
||||
mtp_update_kv_cache(ctx, accepted_batch, false);
|
||||
|
||||
// Clean up the forced state to not affect subsequent, normal decode calls.
|
||||
llama_mtp_cancel_sinfo_update(ctx);
|
||||
|
||||
llama_batch_free(accepted_batch);
|
||||
}
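For context, a rough sketch of how the two halves could fit together in a caller: the main model validates the draft in a single batch, accepted tokens are collected with the usual sampler, and mtp_accept_tokens then refreshes the MTP cache. The greedy accept rule and the use of n_past + 1 as the base position of the first accepted token are assumptions about the server integration, which is not part of this file.

// Sketch: validate `drafts` with the main model and update the MTP cache for the kept tokens.
static std::vector<llama_token> validate_and_accept(
        common_sampler * smpl, llama_context * ctx,
        const llama_tokens & drafts, llama_token id_last, int32_t n_past, llama_seq_id seq_id) {
    llama_batch batch = llama_batch_init((int32_t) drafts.size() + 1, 0, 1);
    common_batch_add(batch, id_last, n_past, { seq_id }, true);
    for (size_t i = 0; i < drafts.size(); ++i) {
        common_batch_add(batch, drafts[i], n_past + 1 + (int32_t) i, { seq_id }, true);
    }

    std::vector<llama_token> accepted;
    if (llama_decode(ctx, batch) == 0) {
        for (int32_t i = 0; i < batch.n_tokens; ++i) {
            const llama_token id = common_sampler_sample(smpl, ctx, i);
            common_sampler_accept(smpl, id, true);
            accepted.push_back(id);
            if (i >= (int32_t) drafts.size() || id != drafts[i]) {
                break; // first disagreement (or the bonus token after the last draft)
            }
        }
        // assumed here: the first accepted token sits at position n_past + 1
        mtp_accept_tokens(ctx, accepted, n_past + 1, seq_id);
    }

    llama_batch_free(batch);
    return accepted;
}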
|
||||
|
|
|
|||
|
|
@ -12,6 +12,12 @@ struct common_speculative_params {
|
|||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
|
||||
struct mtp_kv_update_data {
|
||||
llama_token id;
|
||||
int32_t n_past;
|
||||
int32_t tok_idx;
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
|
|
@ -29,7 +35,40 @@ void common_speculative_add_replacement_tgt_dft(
|
|||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
        struct common_speculative * spec,
        struct common_speculative_params params,
        const llama_tokens & prompt,
        llama_token id_last);
|
||||
|
||||
/**
|
||||
* @brief Generates speculative draft tokens using the Multi-Token Prediction (MTP) architecture.
|
||||
*
|
||||
* This function performs a recursive generation loop using the MTP head (e.g., Eagle/NextN).
|
||||
* It uses the fixed hidden state from the main model's last step and updates the MTP layer's
|
||||
* internal KV cache autoregressively.
|
||||
*
|
||||
* @param smpl The sampler instance.
|
||||
* @param ctx The llama context (shared between Main and MTP).
|
||||
* @param params Speculative parameters (n_draft, p_min).
|
||||
* @param id_last The last confirmed token ID from the main model.
|
||||
* @param n_past The number of tokens in the validated past (start position for drafting).
|
||||
* @param seq_id The sequence ID to use for drafting.
|
||||
*
|
||||
* @return std::vector<llama_token> The generated draft tokens.
|
||||
*/
|
||||
llama_tokens mtp_speculative_gen_draft(
|
||||
struct common_sampler * smpl,
|
||||
struct llama_context * ctx,
|
||||
struct common_speculative_params params,
|
||||
llama_token id_last,
|
||||
int32_t n_past,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, bool is_prompt_warmup);
|
||||
|
||||
void mtp_accept_tokens(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & ids,
|
||||
int32_t n_past_base,
|
||||
llama_seq_id seq_id
|
||||
);
|
||||
|
|
|
|||
|
|
@ -228,6 +228,17 @@ extern "C" {
|
|||
// - if not: only the last token is output
|
||||
// )
|
||||
//
|
||||
typedef enum {
    MTP_OP_NONE,            // regular decode, no MTP involvement
    MTP_OP_WARMUP,          // populate the MTP layer's KV cache for the prompt tokens
    MTP_OP_UPDATE_ACCEPTED, // update the MTP layer's KV cache for accepted draft tokens
    MTP_OP_DRAFT_GEN,       // run the MTP head to generate a draft token
} llama_mtp_op_type;

typedef struct llama_mtp_params {
    llama_mtp_op_type op_type;
} llama_mtp_params;
|
||||
|
||||
typedef struct llama_batch {
|
||||
int32_t n_tokens;
|
||||
|
||||
|
|
@ -237,6 +248,7 @@ extern "C" {
|
|||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
llama_mtp_params mtp_params;
|
||||
} llama_batch;
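To make the new field concrete, a small sketch of driving an MTP decode through the C API; token, position and sequence values are placeholders and assume a model loaded with MTP enabled.

// Sketch: run a single-token MTP draft-generation decode.
static int mtp_draft_decode_example(struct llama_context * ctx, llama_token id_last, llama_pos n_past) {
    struct llama_batch batch = llama_batch_init(1, 0, 1);

    batch.n_tokens           = 1;
    batch.token[0]           = id_last; // last confirmed token
    batch.pos[0]             = n_past;  // position to draft from
    batch.n_seq_id[0]        = 1;
    batch.seq_id[0][0]       = 0;
    batch.logits[0]          = 1;       // we want logits for the drafted position
    batch.mtp_params.op_type = MTP_OP_DRAFT_GEN; // run only the MTP head

    const int ret = llama_decode(ctx, batch);

    llama_batch_free(batch);
    return ret;
}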
|
||||
|
||||
enum llama_model_kv_override_type {
|
||||
|
|
@ -314,6 +326,7 @@ extern "C" {
|
|||
bool use_extra_bufts; // use extra buffer types (used for weight repacking)
|
||||
bool no_host; // bypass host buffer allowing extra buffers to be used
|
||||
bool no_alloc; // only load metadata and simulate memory allocations
|
||||
bool mtp; // use MTP if supported by the model
|
||||
};
|
||||
|
||||
// NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
|
||||
|
|
@ -543,6 +556,8 @@ extern "C" {
|
|||
|
||||
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
|
||||
|
||||
// Returns the number of NextN/MTP prediction layers in the model (0 if the model has none)
LLAMA_API int32_t llama_model_n_nextn_layer(const struct llama_model * model);
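As a small illustration (not from this patch), the getter lets callers gate the drafting path on actual model support; params.mtp refers to the common_params flag added earlier in this change, and the surrounding setup code is assumed.

// only enable MTP drafting when the user asked for it and the model ships NextN/MTP layers
const bool use_mtp_draft = params.mtp && llama_model_n_nextn_layer(model) > 0;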
|
||||
|
||||
// Functions to access the model's GGUF metadata scalar values
|
||||
// - The functions return the length of the string on success, or -1 on failure
|
||||
// - The output string is always null-terminated and cleared on failure
|
||||
|
|
@ -1451,6 +1466,38 @@ extern "C" {
|
|||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
//
|
||||
// MTP
|
||||
//
|
||||
|
||||
/**
 * @brief Sets the main model's last hidden state as the input for MTP draft generation.
 * The buffer is read during the next decode with op_type == MTP_OP_DRAFT_GEN and must remain valid until then.
 */
LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);
|
||||
|
||||
/**
|
||||
* @brief Prepares the context for an MTP KV cache update by creating a resized copy of the last sinfo.
|
||||
* This is used after speculative validation when only a subset of draft tokens are accepted.
|
||||
* @param n_accepted The number of tokens that were accepted and for which the sinfo should be resized.
|
||||
* @return true on success.
|
||||
*/
|
||||
LLAMA_API bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted);
|
||||
|
||||
/**
|
||||
* @brief Prepares the context for an MTP KV cache update by reusing the sinfo from the last main model decode.
|
||||
* This is used for the prompt warmup to ensure the MTP and main model KV caches are perfectly aligned.
|
||||
* @return true on success.
|
||||
*/
|
||||
LLAMA_API bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx);
|
||||
|
||||
/**
|
||||
* @brief Clears the forced sinfo state from the context. Must be called after a decode that used a prepared sinfo.
|
||||
*/
|
||||
LLAMA_API void llama_mtp_cancel_sinfo_update(struct llama_context * ctx);
|
||||
|
||||
/**
|
||||
* @brief Removes KV cache metadata for a specified sequence and token range.
|
||||
* This makes the physical cells logically available again without deleting the tensor data.
|
||||
*/
|
||||
LLAMA_API void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
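To make the call contract above concrete, a hedged sketch of the prepare → decode → cancel sequence; how the accepted batch is built (tokens, positions, and mtp_params.op_type = MTP_OP_UPDATE_ACCEPTED) is up to the caller and only assumed here.

// Sketch: refresh the MTP KV cache for the first `n_accepted` cells of the last validation decode.
static void mtp_update_after_validation(struct llama_context * ctx, const struct llama_batch & accepted_batch, size_t n_accepted) {
    if (!llama_mtp_prepare_sinfo_for_update(ctx, n_accepted)) {
        return; // no sinfo from the last main-model decode to reuse
    }

    // accepted_batch is expected to carry mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED
    llama_decode(ctx, accepted_batch);

    // must always follow a decode that used a prepared sinfo
    llama_mtp_cancel_sinfo_update(ctx);
}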
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -2446,12 +2446,13 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
|
||||
// These tensors only exist in the last layer(s) and are treated as output tensors
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
|
||||
// Changed to LLM_TENSOR_LAYER_REPEATING because we saved these under a blk with a non-negative id
|
||||
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
|
||||
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
};
|
||||
|
||||
LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
|
||||
|
|
|
|||
|
|
@ -301,17 +301,17 @@ bool llama_batch_allocr::init(
|
|||
ok = false;
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
// if (!ok) {
|
||||
// LLAMA_LOG_ERROR(
|
||||
// "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
// " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
// " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
|
||||
// " it is required that the sequence positions remain consecutive: Y = X + 1\n",
|
||||
// __func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// return false;
|
||||
// }
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
|
||||
|
|
@ -874,13 +874,14 @@ struct llama_batch llama_batch_get_one(
|
|||
|
||||
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch batch = {
|
||||
/*n_tokens =*/ 0,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*n_tokens =*/ 0,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
/*n_seq_id =*/ nullptr,
|
||||
/*seq_id =*/ nullptr,
|
||||
/*logits =*/ nullptr,
|
||||
/*.mtp_params =*/ { MTP_OP_NONE },
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "llama-memory.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cmath>
|
||||
|
|
@ -17,6 +18,13 @@
|
|||
//
|
||||
// llama_context
|
||||
//
|
||||
// Key for the graph cache. It contains all parameters that define the graph topology.
|
||||
|
||||
// Per-context state used to coordinate KV cache slot reuse between the main model pass
// and the MTP passes (warmup, accepted-token update, draft generation).
struct llama_context_kv_cache_data {
    // slot infos produced by the last main-model decode (op_type == MTP_OP_NONE)
    llama_kv_cache::slot_info_vec_t last_main_model_sinfos;

    // resized copy of the last sinfos, used when only a subset of draft tokens was accepted
    llama_kv_cache::slot_info_vec_t resized_sinfo_for_force;

    // when non-null, decode bypasses find_slot and reuses these sinfos
    const llama_kv_cache::slot_info_vec_t * forced_sinfos = nullptr;
};
|
||||
|
||||
llama_context::llama_context(
|
||||
const llama_model & model,
|
||||
|
|
@ -136,6 +144,9 @@ llama_context::llama_context(
|
|||
cparams.op_offload = params.op_offload;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
kv_cache_data = new llama_context_kv_cache_data();
|
||||
|
||||
|
||||
{
|
||||
const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
|
||||
graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
|
||||
|
|
@ -476,6 +487,7 @@ llama_context::~llama_context() {
|
|||
}
|
||||
}
|
||||
ggml_opt_free(opt_ctx);
|
||||
delete static_cast<llama_context_kv_cache_data *>(kv_cache_data);
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
|
|
@ -711,6 +723,10 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
|||
return it->second.data();
|
||||
}
|
||||
|
||||
ggml_tensor * llama_context::get_embeddings_tensor() {
|
||||
return embd_tensor;
|
||||
}
|
||||
|
||||
void llama_context::attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch) {
|
||||
|
|
@ -805,7 +821,8 @@ bool llama_context::apply_adapter_cvec(
|
|||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
|
||||
const llama_mtp_params & mtp_params) {
|
||||
if (mctx && !mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
|
|
@ -817,7 +834,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
|||
|
||||
// the new graph parameters
|
||||
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
|
||||
const auto gparams = graph_params(res, ubatch, mctx, gtype);
|
||||
const auto gparams = graph_params(res, ubatch, mctx, gtype, mtp_params);
|
||||
|
||||
if (!graph_reuse_disable && res->can_reuse(gparams)) {
|
||||
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
|
||||
|
|
@ -848,6 +865,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
|||
}
|
||||
}
|
||||
|
||||
if (mtp_params.op_type != MTP_OP_NONE) { // If it is any MTP operation
|
||||
if (!prepare_mtp_graph_inputs(res, ubatch, mtp_params)) {
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// set the input data for the input tensors
|
||||
{
|
||||
//const auto t_start_us = ggml_time_us();
|
||||
|
|
@ -926,7 +950,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
cparams.causal_attn = false;
|
||||
|
||||
ggml_status status;
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, { MTP_OP_NONE });
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
|
|
@ -1034,6 +1058,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
int llama_context::decode(const llama_batch & batch_inp) {
|
||||
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
|
||||
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
|
||||
|
||||
if (!memory) {
|
||||
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
|
||||
return encode(batch_inp);
|
||||
|
|
@ -1088,10 +1114,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// handle any pending shifts/copies
|
||||
memory_update(false);
|
||||
|
||||
llama_memory_context_ptr mctx;
|
||||
std::unique_ptr<llama_memory_context_i> mctx;
|
||||
|
||||
while (true) {
|
||||
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||
mctx = this->initialize_decode_context(batch_inp, output_all);
|
||||
|
||||
if (!mctx) {
|
||||
return -2;
|
||||
}
|
||||
|
|
@ -1108,6 +1135,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||
{
|
||||
if (kvd->forced_sinfos) {
|
||||
LLAMA_LOG_ERROR("%s: Mismatch between ubatches and sinfos during reuse.\n", __func__);
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!did_optimize) {
|
||||
did_optimize = true;
|
||||
|
||||
|
|
@ -1161,7 +1194,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
|
||||
ggml_status status;
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, batch_inp.mtp_params);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module
|
||||
|
|
@ -1208,71 +1241,81 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
// extract logits
|
||||
if (t_logits && n_outputs > 0) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
// MTP operations that are purely for updating the KV cache
|
||||
// (MTP_OP_WARMUP and MTP_OP_UPDATE_ACCEPTED) also produce a logit tensor
|
||||
// as a side effect of running the graph. If these logits are copied
|
||||
// back to the main context buffer, they will overwrite the valid logits
|
||||
// produced by the main model's pass, leading to incorrect sampling.
|
||||
// This condition explicitly prevents that copy for cache-only operations.
|
||||
if (batch_inp.mtp_params.op_type != MTP_OP_WARMUP &&
|
||||
batch_inp.mtp_params.op_type != MTP_OP_UPDATE_ACCEPTED) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
|
||||
float * logits_out = logits + n_outputs_prev*n_vocab;
|
||||
float * logits_out = logits + n_outputs_prev*n_vocab;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
float * embd_out = embd + n_outputs_prev*n_embd;
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
float * embd_out = embd + n_outputs_prev*n_embd;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1479,7 +1522,7 @@ ggml_cgraph * llama_context::graph_reserve(
|
|||
|
||||
auto * res = gf_res_reserve.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
||||
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
|
||||
|
||||
res->reset();
|
||||
|
||||
|
|
@ -1506,8 +1549,9 @@ ggml_cgraph * llama_context::graph_reserve(
|
|||
llm_graph_params llama_context::graph_params(
|
||||
llm_graph_result * res,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_memory_context_i * mctx,
|
||||
llm_graph_type gtype) const {
|
||||
const llama_memory_context_i * mctx,
|
||||
llm_graph_type gtype,
|
||||
const llama_mtp_params & mtp_params) const {
|
||||
return {
|
||||
/*.arch =*/ model.arch,
|
||||
/*.hparams =*/ model.hparams,
|
||||
|
|
@ -1520,12 +1564,28 @@ llm_graph_params llama_context::graph_params(
|
|||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.mtp_params =*/ mtp_params,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
/*.res =*/ res,
|
||||
};
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
|
||||
const auto& vocab = model.vocab;
|
||||
const auto& hparams = model.hparams;
|
||||
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return memory->init_batch(*balloc, 1, false);
|
||||
}
|
||||
|
||||
ggml_status llama_context::graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched) {
|
||||
|
|
@ -2267,7 +2327,7 @@ void llama_context::opt_epoch_iter(
|
|||
|
||||
auto * res = gf_res_prev.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, { MTP_OP_NONE });
|
||||
|
||||
res->reset();
|
||||
|
||||
|
|
@ -3056,3 +3116,122 @@ void llama_opt_epoch(
|
|||
callback_train,
|
||||
callback_eval);
|
||||
}
|
||||
|
||||
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
|
||||
ctx->draft_input_hidden_state = hidden_state;
|
||||
}
|
||||
|
||||
bool llama_mtp_prepare_sinfo_for_warmup(struct llama_context * ctx) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
const auto & last_sinfo = kvd->last_main_model_sinfos;
|
||||
|
||||
if (last_sinfo.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: The main call sinfo is not available for warmup.\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
kvd->forced_sinfos = &last_sinfo;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
bool llama_mtp_prepare_sinfo_for_update(struct llama_context * ctx, size_t n_accepted) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
const auto & last_sinfo = kvd->last_main_model_sinfos;
|
||||
|
||||
if (last_sinfo.empty() || last_sinfo[0].idxs.empty()) {
|
||||
LLAMA_LOG_ERROR("%s: The sinfo for the last main call is not available.", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
kvd->resized_sinfo_for_force = last_sinfo;
|
||||
|
||||
kvd->resized_sinfo_for_force[0].idxs[0].resize(n_accepted);
|
||||
|
||||
kvd->forced_sinfos = &kvd->resized_sinfo_for_force;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void llama_mtp_cancel_sinfo_update(struct llama_context * ctx) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(ctx->kv_cache_data);
|
||||
kvd->forced_sinfos = nullptr;
|
||||
}
|
||||
|
||||
void llama_context::kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
if (memory) {
|
||||
static_cast<llama_kv_cache *>(memory.get())->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
ctx->kv_cache_seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
/*
|
||||
Initializes the memory context for a decode operation.
|
||||
The logic follows a specific priority:
|
||||
1. Warmup: Always use a standard batch initialization.
|
||||
2. Forced S-Info (MTP Updates): If a specific KV cache layout is forced, use it.
|
||||
3. Default: Use a standard batch initialization, and if it's a main model pass,
|
||||
save the resulting s-info for potential future reuse by MTP.
|
||||
*/
|
||||
std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context(const llama_batch & batch_inp, const bool output_all) {
|
||||
auto * kvd = static_cast<llama_context_kv_cache_data *>(kv_cache_data);
|
||||
std::unique_ptr<llama_memory_context_i> mctx;
|
||||
|
||||
if (cparams.warmup) {
|
||||
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||
} else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
|
||||
LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__);
|
||||
mctx = static_cast<llama_kv_cache *>(memory.get())->init_batch_with_sinfos(
|
||||
*balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
|
||||
);
|
||||
} else {
|
||||
mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
|
||||
|
||||
if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
|
||||
if (mctx && mctx->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
kvd->last_main_model_sinfos = static_cast<llama_kv_cache_context *>(mctx.get())->get_sinfos();
|
||||
} else {
|
||||
kvd->last_main_model_sinfos.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mctx;
|
||||
}
|
||||
|
||||
|
||||
bool llama_context::prepare_mtp_graph_inputs(
|
||||
llm_graph_result * res,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_mtp_params & mtp_params) {
|
||||
|
||||
const char * target_tensor_name = "result_embd_pooled";
|
||||
ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);
|
||||
|
||||
const float * source_hidden_state = nullptr;
|
||||
if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
|
||||
source_hidden_state = this->embd;
|
||||
} else { // MTP_OP_DRAFT_GEN
|
||||
source_hidden_state = this->draft_input_hidden_state;
|
||||
}
|
||||
|
||||
if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
|
||||
const char * op_type;
|
||||
if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
|
||||
op_type = "MTP_UPDATE";
|
||||
} else { // MTP_OP_DRAFT_GEN
|
||||
op_type = "DRAFT_GEN";
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
|
||||
} else {
|
||||
LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
|
||||
__func__, target_tensor_name);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,6 +32,8 @@ struct llama_memory_breakdown_data {
|
|||
}
|
||||
};
|
||||
|
||||
struct llama_context_kv_cache_data;
|
||||
|
||||
struct llama_context {
|
||||
// init scheduler and compute buffers, reserve worst-case graphs
|
||||
llama_context(
|
||||
|
|
@ -69,6 +71,11 @@ struct llama_context {
|
|||
float * get_embeddings();
|
||||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
ggml_tensor * get_embeddings_tensor();
|
||||
|
||||
const float * draft_input_hidden_state = nullptr;
|
||||
|
||||
void * kv_cache_data = nullptr;
|
||||
|
||||
void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
|
|
@ -100,6 +107,8 @@ struct llama_context {
|
|||
int32_t il_start,
|
||||
int32_t il_end);
|
||||
|
||||
void kv_cache_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1);
|
||||
|
||||
// process a single ubatch with a specific graph type
|
||||
// if memory_context is provided, it will be applied first to the context's memory
|
||||
// ret contains the status of the graph computation
|
||||
|
|
@ -108,7 +117,8 @@ struct llama_context {
|
|||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
llama_memory_context_i * mctx,
|
||||
ggml_status & ret);
|
||||
ggml_status & ret,
|
||||
const llama_mtp_params & mtp_params);
|
||||
|
||||
int encode(const llama_batch & batch_inp);
|
||||
int decode(const llama_batch & batch_inp);
|
||||
|
|
@ -218,10 +228,21 @@ private:
|
|||
llm_graph_result * res,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_memory_context_i * mctx,
|
||||
llm_graph_type gtype) const;
|
||||
llm_graph_type gtype,
|
||||
const llama_mtp_params & mtp_params) const;
|
||||
|
||||
llm_graph_cb graph_get_cb() const;
|
||||
|
||||
// Methods for MTP decode
|
||||
std::unique_ptr<llama_memory_context_i> initialize_decode_context(const llama_batch & batch_inp, const bool output_all);
|
||||
|
||||
bool prepare_mtp_graph_inputs(
|
||||
llm_graph_result * res,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_mtp_params & mtp_params);
|
||||
|
||||
std::unique_ptr<struct llama_memory_context_i> mtp_memory_batch(const llama_batch & batch_inp);
|
||||
|
||||
// TODO: read/write lora adapters and cvec
|
||||
size_t state_write_data(llama_io_write_i & io);
|
||||
size_t state_read_data (llama_io_read_i & io);
|
||||
|
|
@ -251,6 +272,7 @@ private:
|
|||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
float * embd = nullptr;
|
||||
ggml_tensor * embd_tensor = nullptr;
|
||||
|
||||
// sequence embeddings output (map of [n_embd] vectors)
|
||||
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
||||
|
|
@ -315,4 +337,4 @@ private:
|
|||
mutable int32_t n_eval = 0; // number of eval calls
|
||||
|
||||
mutable int32_t n_reused = 0; // number of times the previous graph was reused
|
||||
};
|
||||
|
|
@ -1254,6 +1254,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
|
|||
return cur;
|
||||
}
|
||||
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const {
|
||||
auto inp = std::make_unique<llm_graph_input_embd>();
|
||||
ggml_tensor * cur = nullptr;
|
||||
|
||||
if (ubatch.token) {
|
||||
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
|
||||
ggml_set_name(inp->tokens, "mtp_inp_tokens");
|
||||
ggml_set_input(inp->tokens);
|
||||
|
||||
cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens);
|
||||
} else {
|
||||
GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings");
|
||||
}
|
||||
|
||||
cb(cur, "mtp_inp_embd", -1);
|
||||
res->add_input(std::move(inp));
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
||||
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ enum llm_graph_type {
|
|||
LLM_GRAPH_TYPE_DEFAULT,
|
||||
LLM_GRAPH_TYPE_ENCODER,
|
||||
LLM_GRAPH_TYPE_DECODER,
|
||||
LLM_GRAPH_TYPE_DRAFT,
|
||||
};
|
||||
|
||||
enum llm_ffn_op_type {
|
||||
|
|
@ -102,6 +103,20 @@ protected:
|
|||
|
||||
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
|
||||
|
||||
// Graph input carrying the hidden state handed from the main model to the MTP head.
// The tensor data is written directly via ggml_backend_tensor_set() in
// llama_context::prepare_mtp_graph_inputs(), so set_input() is a no-op and the
// input never prevents graph reuse.
class llm_graph_input_mtp_states : public llm_graph_input_i {
public:
    llm_graph_input_mtp_states() = default;
    virtual ~llm_graph_input_mtp_states() = default;

    void set_input(const llama_ubatch * /*ubatch*/) override {}

    bool can_reuse(const llm_graph_params & /*params*/) override {
        return true;
    }

    ggml_tensor * states = nullptr;
};
|
||||
|
||||
class llm_graph_input_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_embd() = default;
|
||||
|
|
@ -428,6 +443,7 @@ struct llm_graph_params {
|
|||
const llama_adapter_loras * loras;
|
||||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
llama_mtp_params mtp_params;
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
|
|
@ -476,6 +492,7 @@ struct llm_graph_params {
|
|||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
mtp_params.op_type == other.mtp_params.op_type &&
|
||||
n_outputs == other.n_outputs;
|
||||
}
|
||||
};
|
||||
|
|
@ -690,6 +707,7 @@ struct llm_graph_context {
|
|||
//
|
||||
|
||||
ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
|
||||
ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const;
|
||||
ggml_tensor * build_inp_pos() const;
|
||||
ggml_tensor * build_inp_attn_scale() const;
|
||||
ggml_tensor * build_inp_out_ids() const;
|
||||
|
|
@ -842,4 +860,4 @@ struct llm_graph_context {
|
|||
};
|
||||
|
||||
// TODO: better name
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);
|
||||
|
|
@ -542,6 +542,34 @@ llama_memory_context_ptr llama_kv_cache::init_batch(
|
|||
return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache::init_batch_with_sinfos(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
const slot_info_vec_t & sinfos,
|
||||
bool is_inplace_update) {
|
||||
|
||||
if (sinfos.empty()) {
|
||||
return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
ubatches.push_back(std::move(ubatch));
|
||||
}
|
||||
|
||||
if (ubatches.size() != sinfos.size()) {
|
||||
return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_context>(
|
||||
this, sinfos, std::move(ubatches), is_inplace_update);
|
||||
}
|
||||
|
||||
llama_memory_context_ptr llama_kv_cache::init_full() {
|
||||
return std::make_unique<llama_kv_cache_context>(this);
|
||||
}
|
||||
|
|
@ -888,40 +916,61 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
|
|||
}
|
||||
|
||||
assert(res.s1 >= res.s0);
|
||||
if (!res.empty()) {
|
||||
std::string idxs_str;
|
||||
for (const auto& vec : res.idxs) {
|
||||
if (!vec.empty()) {
|
||||
if (vec.size() > 8) {
|
||||
idxs_str += " [" + std::to_string(vec.front()) + "..." + std::to_string(vec.back()) + " (" + std::to_string(vec.size()) + " cells)]";
|
||||
} else {
|
||||
idxs_str += " [";
|
||||
for(size_t i = 0; i < vec.size(); ++i) {
|
||||
idxs_str += std::to_string(vec[i]) + (i == vec.size() - 1 ? "" : ", ");
|
||||
}
|
||||
idxs_str += "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
seq_pos_max_rm[s] = -1;
|
||||
}
|
||||
void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update) {
|
||||
// For "in-place" updates (MTP warmup/accept), we only update the tensor data.
|
||||
// The cell metadata (logical position, sequence ID) has already been set
|
||||
// by the main model's pass. We must skip all metadata modifications
|
||||
// to prevent `pos_set` from asserting on an already-set cell.
|
||||
if (!is_inplace_update) {
|
||||
// keep track of the max sequence position that we would overwrite with this ubatch
|
||||
// for non-SWA cache, this would be always empty
|
||||
llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
seq_pos_max_rm[s] = -1;
|
||||
}
|
||||
|
||||
assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
|
||||
assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());
|
||||
|
||||
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
||||
for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
|
||||
const uint32_t i = s*sinfo.size() + ii;
|
||||
for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
|
||||
for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
|
||||
const uint32_t i = s*sinfo.size() + ii;
|
||||
|
||||
auto & cells = v_cells[sinfo.strm[s]];
|
||||
auto & cells = v_cells[sinfo.strm[s]];
|
||||
|
||||
const auto idx = sinfo.idxs[s][ii];
|
||||
const auto idx = sinfo.idxs[s][ii];
|
||||
|
||||
if (!cells.is_empty(idx)) {
|
||||
assert(cells.seq_count(idx) == 1);
|
||||
if (!cells.is_empty(idx)) {
|
||||
assert(cells.seq_count(idx) == 1);
|
||||
|
||||
const llama_seq_id seq_id = cells.seq_get(idx);
|
||||
const llama_pos pos = cells.pos_get(idx);
|
||||
const llama_seq_id seq_id = cells.seq_get(idx);
|
||||
const llama_pos pos = cells.pos_get(idx);
|
||||
|
||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
|
||||
|
||||
cells.rm(idx);
|
||||
}
|
||||
cells.rm(idx);
|
||||
}
|
||||
|
||||
cells.pos_set(idx, ubatch.pos[i]);
|
||||
cells.pos_set(idx, ubatch.pos[i]);
|
||||
|
||||
if (ubatch.is_pos_2d()) {
|
||||
llama_kv_cell_ext ext {
|
||||
|
|
@ -931,29 +980,30 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
|
|||
cells.ext_set(idx, ext);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
||||
for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
|
||||
cells.seq_add(idx, ubatch.seq_id[i][s]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
||||
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq_pos_max_rm[s] == -1) {
|
||||
continue;
|
||||
}
|
||||
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
|
||||
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
|
||||
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq_pos_max_rm[s] == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
GGML_ASSERT(s < seq_to_stream.size());
|
||||
GGML_ASSERT(s < seq_to_stream.size());
|
||||
|
||||
auto & cells = v_cells[seq_to_stream[s]];
|
||||
auto & cells = v_cells[seq_to_stream[s]];
|
||||
|
||||
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
|
||||
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
|
||||
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
|
||||
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
|
||||
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
|
||||
|
||||
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
||||
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2010,7 +2060,8 @@ llama_kv_cache_context::llama_kv_cache_context(
|
|||
llama_kv_cache_context::llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
llama_kv_cache::slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
|
||||
std::vector<llama_ubatch> ubatches,
|
||||
bool is_inplace_update) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)), is_inplace_update(is_inplace_update) {
|
||||
}
|
||||
|
||||
llama_kv_cache_context::~llama_kv_cache_context() = default;
|
||||
|
|
@ -2035,7 +2086,7 @@ bool llama_kv_cache_context::apply() {
|
|||
return true;
|
||||
}
|
||||
|
||||
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
|
||||
kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur], is_inplace_update);
|
||||
n_kv = kv->get_n_kv(sinfos[i_cur]);
|
||||
|
||||
return true;
|
||||
|
|
@ -2098,3 +2149,7 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
void llama_kv_cache_context::set_sinfos(llama_kv_cache_context::slot_info_vec_t new_sinfos) {
|
||||
sinfos = new_sinfos;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -118,6 +118,12 @@ public:
|
|||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
bool embd_all) override;
|
||||
|
||||
llama_memory_context_ptr init_batch_with_sinfos(
|
||||
llama_batch_allocr & balloc,
|
||||
uint32_t n_ubatch,
|
||||
const slot_info_vec_t & sinfos,
|
||||
bool is_inplace_update);
|
||||
|
||||
llama_memory_context_ptr init_full() override;
|
||||
|
||||
|
|
@ -182,7 +188,7 @@ public:
|
|||
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
|
||||
|
||||
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
|
||||
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
|
||||
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch, bool is_inplace_update = false);
|
||||
|
||||
//
|
||||
// input API
|
||||
|
|
@ -309,7 +315,8 @@ public:
|
|||
llama_kv_cache_context(
|
||||
llama_kv_cache * kv,
|
||||
slot_info_vec_t sinfos,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
std::vector<llama_ubatch> ubatches,
|
||||
bool is_inplace_update = false);
|
||||
|
||||
virtual ~llama_kv_cache_context();
|
||||
|
||||
|
|
@ -355,6 +362,10 @@ public:
|
|||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_sinfos(slot_info_vec_t new_sinfos);
|
||||
|
||||
const slot_info_vec_t & get_sinfos() const { return sinfos; }
|
||||
|
||||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
|
|
@ -387,4 +398,6 @@ private:
|
|||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// as the cache gets filled, the benefit from this heuristic disappears
|
||||
int32_t n_kv;
|
||||
|
||||
bool is_inplace_update = false;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1782,9 +1782,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
// NextN/MTP parameters
|
||||
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
|
||||
|
||||
// TODO: when MTP is implemented, this should probably be updated if needed
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
|
||||
if (params.mtp) {
|
||||
// Include the NextN/MTP layers in the KV cache when MTP is enabled (e.g. GLM-4.5-Air: all 47 layers instead of 46)
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer;
|
||||
}
|
||||
else {
|
||||
// Otherwise exclude to save memory
|
||||
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
}
|
||||
switch (hparams.n_layer) {
|
||||
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
|
||||
case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open
|
||||
|
|
@ -5215,12 +5220,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
}
|
||||
|
||||
// Load ALL tensors including NextN layer to satisfy total tensor count
|
||||
// but only PROCESS up to last layer (skipping final NextN layer) in forward pass
|
||||
// but skip loading data for NextN layers if MTP is disabled to save VRAM
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
int flags = 0;
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
// skip all tensors in the NextN layers
|
||||
flags |= TENSOR_SKIP;
|
||||
// Skip loading MTP layers if the feature is disabled
|
||||
if (!params.mtp) {
|
||||
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
|
||||
flags |= TENSOR_SKIP;
|
||||
}
|
||||
}
|
||||
|
||||
auto & layer = layers[i];
|
||||
|
|
@ -7908,7 +7915,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
}
|
||||
|
||||
// add on pooling layer
|
||||
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
|
||||
if (params.mtp_params.op_type == MTP_OP_NONE) {
|
||||
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
|
||||
}
|
||||
|
||||
// if the gguf model was converted with --sentence-transformers-dense-modules
|
||||
// there will be two additional dense projection layers
|
||||
|
|
@ -7942,6 +7951,7 @@ llama_model_params llama_model_default_params() {
|
|||
/*.use_extra_bufts =*/ true,
|
||||
/*.no_host =*/ false,
|
||||
/*.no_alloc =*/ false,
|
||||
/*.mtp =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
@ -7999,6 +8009,10 @@ const char * llama_model_cls_label(const struct llama_model * model, uint32_t i)
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t llama_model_n_nextn_layer(const llama_model * model) {
|
||||
return model->hparams.nextn_predict_layers;
|
||||
}
|
||||
|
||||
// deprecated
|
||||
int32_t llama_n_ctx_train(const llama_model * model) {
|
||||
return llama_model_n_ctx_train(model);
|
||||
|
|
|
|||
|
|
@ -2,169 +2,323 @@
llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
int sections[4];
std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
if (params.mtp_params.op_type != MTP_OP_NONE) {
ggml_tensor* hidden_states_from_main_model;
bool use_mrope = hparams.use_mrope();
if (ubatch.embd && !use_mrope) {
// unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
if (params.mtp_params.op_type == MTP_OP_WARMUP || params.mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
} else {
hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
ggml_set_input(hidden_states_from_main_model);
auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
inp_mtp->states = hidden_states_from_main_model;
res->add_input(std::move(inp_mtp));
}
const int il_mtp = hparams.n_layer - 1;
const auto & mtp_layer = model.layers[il_mtp];
res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head, model);
} else {
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
bool use_mrope = hparams.use_mrope();
if (ubatch.embd && !use_mrope) {
// unfortunately, we need to forcefully stop here, to avoid users complaining about wrong results
GGML_ABORT("This GGUF does not support multimodal. Please reconvert it.");
}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
// Only process up to last layer (skip final NextN layer)
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// Pre-attention norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
if (use_mrope) {
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
// Normal RoPE
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// Post-attention norm
cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_attn_norm", il);
// Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
// Dense FFN layer
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(routed_out, "ffn_moe_out", il);
// Process shared expert on original input
ggml_tensor * shared_out = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shared_out, "ffn_shexp_out", il);
// Final output: routed_output + shared_output
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
ggml_build_forward_expand(gf, res->t_logits);
}

ggml_tensor * llm_build_glm4_moe::build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
int64_t n_embd_head, const llama_model & model) {
ggml_tensor * embd_copy = ggml_dup(ctx0, prev_embeddings);
cb(embd_copy, "mtp_embd_copy", -1);
const int il = hparams.n_layer - 1;
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_out_ids = build_inp_out_ids();
// Only process up to last layer (skip final NextN layer)
// Final layer tensors are loaded but not processed in forward pass
const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
for (int il = 0; il < n_transformer_layers; ++il) {
ggml_tensor * inpSA = inpL;
// Pre-attention norm
cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
}
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
}
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
}
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
if (use_mrope) {
Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr,
n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
} else {
// Normal RoPE
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot,
rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
if (il == n_transformer_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// Post-attention norm
cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "post_attn_norm", il);
// Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
// Dense FFN layer
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
model.layers[il].ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(routed_out, "ffn_moe_out", il);
// Process shared expert on original input
ggml_tensor * shared_out = build_ffn(cur,
model.layers[il].ffn_up_shexp, NULL, NULL,
model.layers[il].ffn_gate_shexp, NULL, NULL,
model.layers[il].ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shared_out, "ffn_shexp_out", il);
// Final output: routed_output + shared_output
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
// If nextn.embed_tokens is missing (GLM-4.6), use model.tok_embd
ggml_tensor * mtp_embd_weights = mtp_layer.nextn.embed_tokens;
if (mtp_embd_weights == nullptr) {
mtp_embd_weights = model.tok_embd;
}
cur = inpL;
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
ggml_tensor * token_emb = build_inp_embd_mtp(mtp_embd_weights);
cb(cur, "result_norm", -1);
res->t_embd = cur;
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * hidden_state_norm = build_norm(embd_copy, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
cb(combined, "mtp_concat", il);
ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
// lm_head
cur = build_lora_mm(model.output, cur);
// now proceed through last layer (skipped in main model)
ggml_tensor * inpSA = cur;
// Pre-attention norm for the MTP block
cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "result_output", -1);
res->t_logits = cur;
// self-attention
{
ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
cb(Qcur, "Qcur", il);
ggml_build_forward_expand(gf, cur);
ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// Apply Q/K norm if available (GLM-4.5 355B variant)
if (mtp_layer.attn_q_norm) {
Qcur = build_norm(Qcur, mtp_layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
}
if (mtp_layer.attn_k_norm) {
Kcur = build_norm(Kcur, mtp_layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Kcur_normed", il);
}
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn,
mtp_layer.wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
}
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "mtp_ffn_inp", il);
cur = build_norm(ffn_inp, mtp_layer.attn_post_norm, NULL, LLM_NORM_RMS, il);
// moe ffn for nextn block
{
// Process routed experts using existing MoE infrastructure
ggml_tensor * routed_out = build_moe_ffn(cur,
mtp_layer.ffn_gate_inp,
mtp_layer.ffn_up_exps,
mtp_layer.ffn_gate_exps,
mtp_layer.ffn_down_exps,
mtp_layer.ffn_exp_probs_b,
n_expert, n_expert_used,
LLM_FFN_SILU, hparams.expert_weights_norm,
true, hparams.expert_weights_scale,
(llama_expert_gating_func_type) hparams.expert_gating_func,
il);
cb(routed_out, "ffn_moe_out", il);
// Process shared expert on original input
ggml_tensor * shared_out = build_ffn(cur,
mtp_layer.ffn_up_shexp, NULL, NULL,
mtp_layer.ffn_gate_shexp, NULL, NULL,
mtp_layer.ffn_down_shexp, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(shared_out, "ffn_shexp_out", il);
// Final output: routed_output + shared_output
cur = ggml_add(ctx0, routed_out, shared_out);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "mtp_ffn_out_resid", il);
cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
// If nextn.shared_head_head is missing (GLM-4.6), use model.output (Main LM Head)
ggml_tensor * mtp_head_weights = mtp_layer.nextn.shared_head_head;
if (mtp_head_weights == nullptr) {
mtp_head_weights = model.output;
}
cur = build_lora_mm(mtp_head_weights, cur);
return cur;
}
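For reference, the MTP head assembled by build_mtp_tail follows the usual NextN recipe: RMS-normalize the incoming token embedding (enorm) and the previous token's hidden state (hnorm), concatenate them along the embedding dimension, and project the result back to n_embd through eh_proj before running the single NextN transformer block and the shared head. The following is a minimal, self-contained sketch of that shape flow only; it uses plain C++ vectors and hypothetical helper names as stand-ins for ggml tensors, it is not part of this patch, and the learned norm scales are omitted.

#include <cmath>
#include <cstddef>
#include <vector>

// Plain RMS norm over one row vector (learned scale weight omitted for brevity).
static std::vector<float> rms_norm(const std::vector<float> & x, float eps = 1e-6f) {
    float ss = 0.0f;
    for (float v : x) ss += v * v;
    const float scale = 1.0f / std::sqrt(ss / x.size() + eps);
    std::vector<float> out(x.size());
    for (size_t i = 0; i < x.size(); ++i) out[i] = x[i] * scale;
    return out;
}

// eh_proj is treated here as a row-major [n_embd x 2*n_embd] matrix mapping
// concat(norm(token_emb), norm(prev_hidden)) back down to n_embd.
static std::vector<float> mtp_project(const std::vector<float> & eh_proj,
                                      const std::vector<float> & token_emb,
                                      const std::vector<float> & prev_hidden) {
    const size_t n_embd = token_emb.size();
    std::vector<float> combined = rms_norm(token_emb);      // [n_embd]
    const std::vector<float> h   = rms_norm(prev_hidden);   // [n_embd]
    combined.insert(combined.end(), h.begin(), h.end());     // [2*n_embd]

    std::vector<float> out(n_embd, 0.0f);                    // [n_embd]
    for (size_t r = 0; r < n_embd; ++r) {
        for (size_t c = 0; c < 2 * n_embd; ++c) {
            out[r] += eh_proj[r * 2 * n_embd + c] * combined[c];
        }
    }
    return out; // fed into the NextN transformer block, then the shared head
}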
@ -220,6 +220,8 @@ struct llm_build_glm4 : public llm_graph_context {
struct llm_build_glm4_moe : public llm_graph_context {
llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params);

ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings, int64_t n_embd_head, const llama_model & model);
};

struct llm_build_gpt2 : public llm_graph_context {
@ -80,6 +80,7 @@ struct server_slot {
mtmd_context * mctx = nullptr;

common_speculative * spec = nullptr;
bool has_mtp = false;

std::unique_ptr<const server_task> task;
std::unique_ptr<const server_task> task_prev; // used for debugging
@ -206,7 +207,7 @@ struct server_slot {
bool need_embd() const {
GGML_ASSERT(task);

return server_task_type_need_embd(task->type);
return server_task_type_need_embd(task->type) || has_mtp;
}

bool need_logits() const {
@ -220,7 +221,8 @@ struct server_slot {
bool can_split() const {
return
!need_embd() ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST) ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_NONE);
}

bool can_batch_with(server_slot & other_slot) const {
@ -252,7 +254,7 @@ struct server_slot {
}

bool can_speculate() const {
return ctx_dft;
return (ctx_dft || has_mtp);
}

void add_token(const completion_token_output & token) {
@ -776,6 +778,18 @@ private:
}
}

// if model has MTP and no draft model is specified...
else if (llama_model_n_nextn_layer(model) > 0 && params_base.mtp) {
SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
slot.has_mtp = true;

slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);

SRV_INF("%s (n_max=%d)\n", "MTP needs embeddings on decode, enabling", params_base.speculative.n_max);
llama_set_embeddings(ctx, true);
}

SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);

slot.callback_on_release = [this](int) {
@ -1974,12 +1988,34 @@ private:
GGML_ABORT("not supported by multimodal");
}

llama_tokens draft;

struct common_speculative_params params_spec;
params_spec.n_draft = n_draft_max;
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
params_spec.p_min = slot.task->params.speculative.p_min;
const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);

if (slot.ctx_dft) {
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max;
} else {
params_spec.n_reuse = 0;
}

if (slot.has_mtp) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, -1));

draft = mtp_speculative_gen_draft(
slot.smpl.get(),
ctx,
params_spec,
slot.sampled,
slot.prompt.n_tokens(),
slot.id
);
}
else {
const llama_tokens& cached_text_tokens = slot.prompt.tokens.get_text_tokens();
draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled);
}

// add the sampled token to the batch
slot.i_batch_dft.push_back(batch.n_tokens);
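The drafts produced here, whether they come from mtp_speculative_gen_draft or from common_speculative_gen_draft, are then verified against the main model in the usual speculative-decoding way: the target model scores all drafted positions in one batch and only the agreeing prefix survives. A minimal, self-contained sketch of that acceptance rule follows; it is illustrative plain C++ under that general assumption, not the server's exact verification code.

#include <cstddef>
#include <cstdint>
#include <vector>

using token_id = int32_t; // stand-in for llama_token

// draft:  tokens proposed by the draft/MTP pass
// target: tokens the main model samples at the same positions (typically draft.size() + 1 of them)
static std::vector<token_id> accept_longest_prefix(const std::vector<token_id> & draft,
                                                   const std::vector<token_id> & target) {
    std::vector<token_id> accepted;
    for (size_t i = 0; i < target.size(); ++i) {
        accepted.push_back(target[i]);                     // target tokens are always valid output
        if (i < draft.size() && draft[i] != target[i]) {
            break;                                         // first mismatch ends the accepted run
        }
    }
    return accepted; // accepted ids are what a caller would feed back (e.g. into mtp_accept_tokens)
}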
@ -2592,6 +2628,21 @@ private:
continue; // continue loop of n_batch
}

if (slot_batched && slot_batched->has_mtp &&
(slot_batched->state == SLOT_STATE_PROCESSING_PROMPT || slot_batched->state == SLOT_STATE_DONE_PROMPT)) {

// Prepare the context to reuse the exact sinfo layout (including multiple u-batches)
// from the main model's prompt processing pass. This ensures the MTP layer's
// KV cache is perfectly aligned.
if (llama_mtp_prepare_sinfo_for_warmup(ctx)) {
mtp_update_kv_cache(ctx, batch_view, true);
// Clean up the forced state to not affect subsequent decodes.
llama_mtp_cancel_sinfo_update(ctx);
} else {
LOG_ERR("%s: failed to prepare the MTP for warmup\n", __func__);
}
}

// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;
@ -2711,6 +2762,16 @@ private:
slot.i_batch_dft.clear();
slot.drafted.clear();

if (slot.has_mtp) {
if (!ids.empty()) {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
} else {
llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
}

mtp_accept_tokens(ctx, ids, slot.prompt.n_tokens(), slot.id);
}

slot.n_decoded += ids.size();

slot.t_token_generation = std::max<int64_t>(1, t_current - slot.t_start_generation) / 1e3;