server : add encoder-decoder model support (T5, BART, MADLAD)

Add support for encoder-decoder models in llama-server, matching the
behavior of llama-cli. This enables translation models like MADLAD
and other T5-based models to work with the server.

Changes:
- Add has_encoder flag to detect encoder-decoder models at load time
- Implement llama_encode() call for encoder-decoder prompt processing
- Use decoder_start_token to initialize the decoder after encoding
  (sketched below the list)
- Clear decoder KV cache before each new request (no prefix caching)
- Disable incompatible features for encoder-decoder models:
  - Context shift (encoder outputs are fixed)
  - Speculative decoding (not supported)
  - Prompt caching (encoder outputs depend on entire input)
  - Slot selection by LCP similarity (meaningless for enc-dec)
- Add edge case handling for empty text tokens
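
The has_encoder detection and decoder seeding above boil down to a handful
of public llama.cpp API calls. A minimal sketch, assuming a loaded
llama_model * model (variable names illustrative):

    bool has_encoder = llama_model_has_encoder(model);

    // token that seeds the decoder; fall back to BOS when the model
    // does not define an explicit decoder start token
    llama_token dst = llama_model_decoder_start_token(model);
    if (dst == LLAMA_TOKEN_NULL) {
        dst = llama_vocab_bos(llama_model_get_vocab(model));
    }

T5-family checkpoints usually define an explicit start token (typically the
pad token), so the BOS fallback is only a safety net.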

The encoder processes the full prompt, then the decoder generates
output using cross-attention to the encoder's hidden states.
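
A condensed sketch of that flow against the public llama.h API - not the
server's actual control flow; the function name, prompt tokenization, and
error handling are illustrative:

    #include "llama.h"
    #include <vector>

    // encode the prompt once, then decode token by token; each
    // llama_decode() call cross-attends to the stored encoder output
    static void encode_then_decode(llama_model * model, llama_context * ctx,
                                   std::vector<llama_token> prompt, int n_predict) {
        const llama_vocab * vocab = llama_model_get_vocab(model);

        // 1. run the encoder over the full prompt
        if (llama_encode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size())) != 0) {
            return; // encoder failed
        }

        // 2. seed the decoder (BOS fallback as in the patch)
        llama_token tok = llama_model_decoder_start_token(model);
        if (tok == LLAMA_TOKEN_NULL) {
            tok = llama_vocab_bos(vocab);
        }

        llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

        // 3. autoregressive decoding; stop on end-of-generation
        for (int i = 0; i < n_predict; i++) {
            if (llama_decode(ctx, llama_batch_get_one(&tok, 1)) != 0) {
                break;
            }
            tok = llama_sampler_sample(smpl, ctx, -1);
            if (llama_vocab_is_eog(vocab, tok)) {
                break;
            }
            // ... emit the piece for `tok` here ...
        }

        llama_sampler_free(smpl);
    }

The encoder output is held inside the llama_context after llama_encode(),
which is also why the server re-encodes per request instead of reusing
cached prefixes.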
Turkka Mannila 2025-12-12 11:47:23 +02:00
parent a81a569577
commit 52c283a951
1 changed file with 121 additions and 6 deletions


@@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
     SLOT_STATE_WAIT_OTHER,        // after assigning a task, but waiting for parent slot to process prompt
     SLOT_STATE_STARTED,           // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -529,6 +529,7 @@ struct server_context_impl {
     llama_batch batch {};
     bool add_bos_token = true;
+    bool has_encoder = false; // true if model is encoder-decoder (e.g., T5, BART)
     int32_t n_ctx; // total context for all clients / slots
@@ -593,6 +594,23 @@ struct server_context_impl {
     n_ctx = llama_n_ctx(ctx);
     add_bos_token = llama_vocab_get_add_bos(vocab);
+    has_encoder = llama_model_has_encoder(model);
+
+    if (has_encoder) {
+        SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", "");
+        // warn about incompatible features
+        if (params_base.ctx_shift) {
+            SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", "");
+            params_base.ctx_shift = false;
+        }
+
+        // Note: prompt caching is disabled for encoder-decoder models
+        // (encoder outputs depend on entire input, prefix caching doesn't apply)
+        if (params_base.has_speculative()) {
+            SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", "");
+            // Note: speculative setup continues below but won't be used for enc-dec slots
+        }
+    }
+
     if (params_base.has_speculative()) {
         SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
@@ -726,7 +744,8 @@ struct server_context_impl {
     slot.mctx = mctx;
     slot.prompt.tokens.has_mtmd = mctx != nullptr;
-    if (model_dft) {
+    // speculative decoding is not supported for encoder-decoder models
+    if (model_dft && !has_encoder) {
         slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
         // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
@@ -841,7 +860,8 @@ struct server_context_impl {
     bool update_cache = false;
     // find the slot that has at least n% prompt similarity
-    if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+    // skip for encoder-decoder models - slot.prompt.tokens only contains decoder tokens
+    if (ret == nullptr && slot_prompt_similarity != 0.0f && !has_encoder) {
         float sim_best = 0;
         for (server_slot & slot : slots) {
@@ -919,6 +939,11 @@ struct server_context_impl {
     // TODO: mtmd does not support prompt cache
     update_cache = update_cache && (ret->mctx == nullptr);
+
+    // encoder-decoder models don't support prompt caching:
+    // - encoder outputs depend on the entire input, not just a prefix
+    // - we always clear the decoder KV cache and re-encode
+    update_cache = update_cache && !has_encoder;
     if (update_cache) {
         SRV_WRN("%s", "updating prompt cache\n");
@@ -1928,11 +1953,101 @@ struct server_context_impl {
     slot.t_start_process_prompt = ggml_time_us();
     slot.t_start_generation = 0;
-    slot.state = SLOT_STATE_PROCESSING_PROMPT;
     SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
             slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());
+    // encoder-decoder model handling (e.g., T5, BART, MADLAD)
+    if (has_encoder) {
+        SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens());
+
+        // clear the decoder KV cache for this slot - encoder-decoder models
+        // don't support prefix caching, so we always start fresh
+        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+        slot.prompt.tokens.clear();
+
+        // empty prompt check
+        if (input_tokens.empty()) {
+            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+            slot.print_timings();
+            send_final_response(slot);
+            slot.release();
+            continue;
+        }
+
+        // get the text tokens for encoding
+        const llama_tokens & text_tokens = input_tokens.get_text_tokens();
+
+        // check for empty text tokens (could happen with multimodal-only input)
+        if (text_tokens.empty()) {
+            SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n");
+            send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST);
+            slot.release();
+            continue;
+        }
+
+        // build encoder batch with all prompt tokens
+        // Note: we need to allocate a proper batch with seq_id support
+        llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1);
+        batch_enc.n_tokens = text_tokens.size();
+        for (size_t i = 0; i < text_tokens.size(); i++) {
+            batch_enc.token[i] = text_tokens[i];
+            batch_enc.pos[i] = i;
+            batch_enc.n_seq_id[i] = 1;
+            batch_enc.seq_id[i][0] = slot.id;
+            batch_enc.logits[i] = false;
+        }
+
+        // encode the entire prompt
+        const int ret = llama_encode(ctx, batch_enc);
+
+        // free the encoder batch
+        llama_batch_free(batch_enc);
+
+        if (ret != 0) {
+            SLT_ERR(slot, "llama_encode() failed with error %d\n", ret);
+            send_error(slot, "encoder failed", ERROR_TYPE_SERVER);
+            slot.release();
+            continue;
+        }
+
+        SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens());
+
+        // get decoder start token
+        llama_token decoder_start_token = llama_model_decoder_start_token(model);
+        if (decoder_start_token == LLAMA_TOKEN_NULL) {
+            decoder_start_token = llama_vocab_bos(vocab);
+        }
+
+        SLT_DBG(slot, "decoder start token: %d '%s'\n",
+                decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str());
+
+        // add decoder start token to the batch
+        common_batch_add(batch, decoder_start_token, 0, { slot.id }, true);
+
+        // update slot state - we've processed all prompt tokens (via encoder)
+        // and the decoder is ready to generate
+        slot.prompt.tokens.clear();
+        slot.prompt.tokens.push_back(decoder_start_token);
+        slot.n_prompt_tokens_processed = slot.task->n_tokens();
+
+        common_sampler_reset(slot.smpl);
+        slot.n_decoded = 0;
+        slot.i_batch = batch.n_tokens - 1;
+        slot.state = SLOT_STATE_DONE_PROMPT;
+
+        SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", "");
+
+        if (!slot_batched) {
+            slot_batched = &slot;
+        }
+
+        continue; // skip normal prompt processing
+    }
+
+    slot.state = SLOT_STATE_PROCESSING_PROMPT;
+
     // print prompt tokens (for debugging)
     /*if (1) {
         // first 16 tokens (avoid flooding logs)