Merge 52c283a951 into 58062860af
This commit is contained in:
commit
a42f8e0c6d
|
|
@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
|
|||
// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
|
||||
enum slot_state {
|
||||
SLOT_STATE_IDLE,
|
||||
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
|
||||
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
|
||||
SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
|
||||
SLOT_STATE_STARTED, // after assigning a task and about to process prompt
|
||||
SLOT_STATE_PROCESSING_PROMPT,
|
||||
SLOT_STATE_DONE_PROMPT,
|
||||
SLOT_STATE_GENERATING,
|
||||
|
|
@ -529,6 +529,7 @@ struct server_context_impl {
|
|||
llama_batch batch {};
|
||||
|
||||
bool add_bos_token = true;
|
||||
bool has_encoder = false; // true if model is encoder-decoder (e.g., T5, BART)
|
||||
|
||||
int32_t n_ctx; // total context for all clients / slots
|
||||
|
||||
|
|
@ -590,6 +591,23 @@ struct server_context_impl {
|
|||
n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
add_bos_token = llama_vocab_get_add_bos(vocab);
|
||||
has_encoder = llama_model_has_encoder(model);
|
||||
|
||||
if (has_encoder) {
|
||||
SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", "");
|
||||
|
||||
// warn about incompatible features
|
||||
if (params_base.ctx_shift) {
|
||||
SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", "");
|
||||
params_base.ctx_shift = false;
|
||||
}
|
||||
// Note: prompt caching is disabled for encoder-decoder models
|
||||
// (encoder outputs depend on entire input, prefix caching doesn't apply)
|
||||
if (params_base.has_speculative()) {
|
||||
SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", "");
|
||||
// Note: speculative setup continues below but won't be used for enc-dec slots
|
||||
}
|
||||
}
|
||||
|
||||
if (params_base.has_speculative()) {
|
||||
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
|
||||
|
|
@ -723,7 +741,8 @@ struct server_context_impl {
|
|||
slot.mctx = mctx;
|
||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||
|
||||
if (model_dft) {
|
||||
// speculative decoding is not supported for encoder-decoder models
|
||||
if (model_dft && !has_encoder) {
|
||||
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
|
||||
|
||||
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
|
||||
|
|
@ -838,7 +857,8 @@ struct server_context_impl {
|
|||
bool update_cache = false;
|
||||
|
||||
// find the slot that has at least n% prompt similarity
|
||||
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
|
||||
// skip for encoder-decoder models - slot.prompt.tokens only contains decoder tokens
|
||||
if (ret == nullptr && slot_prompt_similarity != 0.0f && !has_encoder) {
|
||||
float sim_best = 0;
|
||||
|
||||
for (server_slot & slot : slots) {
|
||||
|
|
@ -916,6 +936,11 @@ struct server_context_impl {
|
|||
// TODO: mtmd does not support prompt cache
|
||||
update_cache = update_cache && (ret->mctx == nullptr);
|
||||
|
||||
// encoder-decoder models don't support prompt caching:
|
||||
// - encoder outputs depend on the entire input, not just a prefix
|
||||
// - we always clear the decoder KV cache and re-encode
|
||||
update_cache = update_cache && !has_encoder;
|
||||
|
||||
if (update_cache) {
|
||||
SRV_WRN("%s", "updating prompt cache\n");
|
||||
|
||||
|
|
@ -1921,11 +1946,101 @@ struct server_context_impl {
|
|||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_generation = 0;
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
|
||||
slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
|
||||
|
||||
// encoder-decoder model handling (e.g., T5, BART, MADLAD)
|
||||
if (has_encoder) {
|
||||
SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens());
|
||||
|
||||
// clear the decoder KV cache for this slot - encoder-decoder models
|
||||
// don't support prefix caching, so we always start fresh
|
||||
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
|
||||
slot.prompt.tokens.clear();
|
||||
|
||||
// empty prompt check
|
||||
if (input_tokens.empty()) {
|
||||
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
|
||||
slot.print_timings();
|
||||
send_final_response(slot);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
// get the text tokens for encoding
|
||||
const llama_tokens & text_tokens = input_tokens.get_text_tokens();
|
||||
|
||||
// check for empty text tokens (could happen with multimodal-only input)
|
||||
if (text_tokens.empty()) {
|
||||
SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n");
|
||||
send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
// build encoder batch with all prompt tokens
|
||||
// Note: we need to allocate a proper batch with seq_id support
|
||||
llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1);
|
||||
batch_enc.n_tokens = text_tokens.size();
|
||||
|
||||
for (size_t i = 0; i < text_tokens.size(); i++) {
|
||||
batch_enc.token[i] = text_tokens[i];
|
||||
batch_enc.pos[i] = i;
|
||||
batch_enc.n_seq_id[i] = 1;
|
||||
batch_enc.seq_id[i][0] = slot.id;
|
||||
batch_enc.logits[i] = false;
|
||||
}
|
||||
|
||||
// encode the entire prompt
|
||||
const int ret = llama_encode(ctx, batch_enc);
|
||||
|
||||
// free the encoder batch
|
||||
llama_batch_free(batch_enc);
|
||||
if (ret != 0) {
|
||||
SLT_ERR(slot, "llama_encode() failed with error %d\n", ret);
|
||||
send_error(slot, "encoder failed", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens());
|
||||
|
||||
// get decoder start token
|
||||
llama_token decoder_start_token = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token == LLAMA_TOKEN_NULL) {
|
||||
decoder_start_token = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "decoder start token: %d '%s'\n",
|
||||
decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str());
|
||||
|
||||
// add decoder start token to the batch
|
||||
common_batch_add(batch, decoder_start_token, 0, { slot.id }, true);
|
||||
|
||||
// update slot state - we've processed all prompt tokens (via encoder)
|
||||
// and the decoder is ready to generate
|
||||
slot.prompt.tokens.clear();
|
||||
slot.prompt.tokens.push_back(decoder_start_token);
|
||||
slot.n_prompt_tokens_processed = slot.task->n_tokens();
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
|
||||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
||||
SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", "");
|
||||
|
||||
if (!slot_batched) {
|
||||
slot_batched = &slot;
|
||||
}
|
||||
|
||||
continue; // skip normal prompt processing
|
||||
}
|
||||
|
||||
slot.state = SLOT_STATE_PROCESSING_PROMPT;
|
||||
|
||||
// print prompt tokens (for debugging)
|
||||
/*if (1) {
|
||||
// first 16 tokens (avoid flooding logs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue