From 52c283a951f2c6fe3c3ecc3ce802b6421577fd95 Mon Sep 17 00:00:00 2001 From: Turkka Mannila Date: Fri, 12 Dec 2025 11:47:23 +0200 Subject: [PATCH] server : add encoder-decoder model support (T5, BART, MADLAD) Add support for encoder-decoder models in llama-server, matching the behavior of llama-cli. This enables translation models like MADLAD and other T5-based models to work with the server. Changes: - Add has_encoder flag to detect encoder-decoder models at load time - Implement llama_encode() call for encoder-decoder prompt processing - Use decoder_start_token to initialize decoder after encoding - Clear decoder KV cache before each new request (no prefix caching) - Disable incompatible features for encoder-decoder models: - Context shift (encoder outputs are fixed) - Speculative decoding (not supported) - Prompt caching (encoder outputs depend on entire input) - Slot selection by LCP similarity (meaningless for enc-dec) - Add edge case handling for empty text tokens The encoder processes the full prompt, then the decoder generates output using cross-attention to the encoder's hidden states. --- tools/server/server-context.cpp | 127 ++++++++++++++++++++++++++++++-- 1 file changed, 121 insertions(+), 6 deletions(-) diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 5a67f508df..f45a6d36c1 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1; // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 enum slot_state { SLOT_STATE_IDLE, - SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt - SLOT_STATE_STARTED, // after assigning a task and about to process prompt + SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt + SLOT_STATE_STARTED, // after assigning a task and about to process prompt SLOT_STATE_PROCESSING_PROMPT, SLOT_STATE_DONE_PROMPT, SLOT_STATE_GENERATING, @@ -529,6 +529,7 @@ struct server_context_impl { llama_batch batch {}; bool add_bos_token = true; + bool has_encoder = false; // true if model is encoder-decoder (e.g., T5, BART) int32_t n_ctx; // total context for all clients / slots @@ -593,6 +594,23 @@ struct server_context_impl { n_ctx = llama_n_ctx(ctx); add_bos_token = llama_vocab_get_add_bos(vocab); + has_encoder = llama_model_has_encoder(model); + + if (has_encoder) { + SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", ""); + + // warn about incompatible features + if (params_base.ctx_shift) { + SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", ""); + params_base.ctx_shift = false; + } + // Note: prompt caching is disabled for encoder-decoder models + // (encoder outputs depend on entire input, prefix caching doesn't apply) + if (params_base.has_speculative()) { + SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", ""); + // Note: speculative setup continues below but won't be used for enc-dec slots + } + } if (params_base.has_speculative()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); @@ -726,7 +744,8 @@ struct server_context_impl { slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - if (model_dft) { + // speculative decoding is not supported for encoder-decoder models + if (model_dft && !has_encoder) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] @@ -841,7 +860,8 @@ struct server_context_impl { bool update_cache = false; // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { + // skip for encoder-decoder models - slot.prompt.tokens only contains decoder tokens + if (ret == nullptr && slot_prompt_similarity != 0.0f && !has_encoder) { float sim_best = 0; for (server_slot & slot : slots) { @@ -919,6 +939,11 @@ struct server_context_impl { // TODO: mtmd does not support prompt cache update_cache = update_cache && (ret->mctx == nullptr); + // encoder-decoder models don't support prompt caching: + // - encoder outputs depend on the entire input, not just a prefix + // - we always clear the decoder KV cache and re-encode + update_cache = update_cache && !has_encoder; + if (update_cache) { SRV_WRN("%s", "updating prompt cache\n"); @@ -1928,11 +1953,101 @@ struct server_context_impl { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); + // encoder-decoder model handling (e.g., T5, BART, MADLAD) + if (has_encoder) { + SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens()); + + // clear the decoder KV cache for this slot - encoder-decoder models + // don't support prefix caching, so we always start fresh + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); + + // empty prompt check + if (input_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + slot.print_timings(); + send_final_response(slot); + slot.release(); + continue; + } + + // get the text tokens for encoding + const llama_tokens & text_tokens = input_tokens.get_text_tokens(); + + // check for empty text tokens (could happen with multimodal-only input) + if (text_tokens.empty()) { + SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n"); + send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST); + slot.release(); + continue; + } + + // build encoder batch with all prompt tokens + // Note: we need to allocate a proper batch with seq_id support + llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1); + batch_enc.n_tokens = text_tokens.size(); + + for (size_t i = 0; i < text_tokens.size(); i++) { + batch_enc.token[i] = text_tokens[i]; + batch_enc.pos[i] = i; + batch_enc.n_seq_id[i] = 1; + batch_enc.seq_id[i][0] = slot.id; + batch_enc.logits[i] = false; + } + + // encode the entire prompt + const int ret = llama_encode(ctx, batch_enc); + + // free the encoder batch + llama_batch_free(batch_enc); + if (ret != 0) { + SLT_ERR(slot, "llama_encode() failed with error %d\n", ret); + send_error(slot, "encoder failed", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens()); + + // get decoder start token + llama_token decoder_start_token = llama_model_decoder_start_token(model); + if (decoder_start_token == LLAMA_TOKEN_NULL) { + decoder_start_token = llama_vocab_bos(vocab); + } + + SLT_DBG(slot, "decoder start token: %d '%s'\n", + decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str()); + + // add decoder start token to the batch + common_batch_add(batch, decoder_start_token, 0, { slot.id }, true); + + // update slot state - we've processed all prompt tokens (via encoder) + // and the decoder is ready to generate + slot.prompt.tokens.clear(); + slot.prompt.tokens.push_back(decoder_start_token); + slot.n_prompt_tokens_processed = slot.task->n_tokens(); + + common_sampler_reset(slot.smpl); + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + slot.state = SLOT_STATE_DONE_PROMPT; + + SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", ""); + + if (!slot_batched) { + slot_batched = &slot; + } + + continue; // skip normal prompt processing + } + + slot.state = SLOT_STATE_PROCESSING_PROMPT; + // print prompt tokens (for debugging) /*if (1) { // first 16 tokens (avoid flooding logs)