server : add encoder-decoder model support (T5, BART, MADLAD)
Add support for encoder-decoder models in llama-server, matching the behavior of llama-cli. This enables translation models like MADLAD and other T5-based models to work with the server.

Changes:
- Add has_encoder flag to detect encoder-decoder models at load time
- Implement llama_encode() call for encoder-decoder prompt processing
- Use decoder_start_token to initialize decoder after encoding
- Clear decoder KV cache before each new request (no prefix caching)
- Disable incompatible features for encoder-decoder models:
  - Context shift (encoder outputs are fixed)
  - Speculative decoding (not supported)
  - Prompt caching (encoder outputs depend on entire input)
  - Slot selection by LCP similarity (meaningless for enc-dec)
- Add edge case handling for empty text tokens

The encoder processes the full prompt, then the decoder generates output using cross-attention to the encoder's hidden states.
commit 52c283a951
parent a81a569577
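For orientation, here is a rough sketch of the encode-then-decode flow the message describes, written against the llama.cpp C API calls that appear in the diff below (llama_encode, llama_model_decoder_start_token, llama_memory_seq_rm). It is a simplified single-request illustration, not the server code itself; the helper name and `prompt_tokens` are placeholders.

// Sketch only: assumes ctx/model are already loaded and prompt_tokens holds the
// tokenized input; error handling is reduced to a bool, sequence 0 is used throughout.
#include "llama.h"

#include <vector>

static bool encode_then_prime_decoder(llama_context * ctx, const llama_model * model,
                                      const std::vector<llama_token> & prompt_tokens) {
    // no prefix caching: clear the decoder KV cache and re-encode every request
    llama_memory_seq_rm(llama_get_memory(ctx), 0, -1, -1);

    // run the encoder over the full prompt in a single batch
    llama_batch enc = llama_batch_init((int32_t) prompt_tokens.size(), 0, 1);
    enc.n_tokens = (int32_t) prompt_tokens.size();
    for (size_t i = 0; i < prompt_tokens.size(); i++) {
        enc.token[i]     = prompt_tokens[i];
        enc.pos[i]       = (llama_pos) i;
        enc.n_seq_id[i]  = 1;
        enc.seq_id[i][0] = 0;
        enc.logits[i]    = false; // encoder outputs feed cross-attention, no logits needed
    }
    const int rc = llama_encode(ctx, enc);
    llama_batch_free(enc);
    if (rc != 0) {
        return false;
    }

    // prime the decoder with the decoder start token, falling back to BOS
    llama_token dec_start = llama_model_decoder_start_token(model);
    if (dec_start == LLAMA_TOKEN_NULL) {
        dec_start = llama_vocab_bos(llama_model_get_vocab(model));
    }
    llama_batch dec = llama_batch_get_one(&dec_start, 1);

    // from here on, sampling proceeds token by token as with any decoder-only model
    return llama_decode(ctx, dec) == 0;
}

The per-slot version in the diff differs mainly in that it tracks slot state, reports errors back to the client, and queues the decoder start token into the shared server batch instead of decoding it immediately.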
@@ -35,8 +35,8 @@ constexpr int HTTP_POLLING_SECONDS = 1;
 // state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
     SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt
     SLOT_STATE_STARTED, // after assigning a task and about to process prompt
     SLOT_STATE_PROCESSING_PROMPT,
     SLOT_STATE_DONE_PROMPT,
     SLOT_STATE_GENERATING,
@@ -529,6 +529,7 @@ struct server_context_impl {
     llama_batch batch {};

     bool add_bos_token = true;
+    bool has_encoder = false; // true if model is encoder-decoder (e.g., T5, BART)

     int32_t n_ctx; // total context for all clients / slots

@@ -593,6 +594,23 @@ struct server_context_impl {
        n_ctx = llama_n_ctx(ctx);

        add_bos_token = llama_vocab_get_add_bos(vocab);
+       has_encoder = llama_model_has_encoder(model);
+
+       if (has_encoder) {
+           SRV_INF("model has encoder - encoder-decoder mode enabled (e.g., T5, BART)%s\n", "");
+
+           // warn about incompatible features
+           if (params_base.ctx_shift) {
+               SRV_WRN("encoder-decoder models do not support context shift - disabling%s\n", "");
+               params_base.ctx_shift = false;
+           }
+           // Note: prompt caching is disabled for encoder-decoder models
+           // (encoder outputs depend on entire input, prefix caching doesn't apply)
+           if (params_base.has_speculative()) {
+               SRV_WRN("encoder-decoder models do not support speculative decoding - ignoring draft model%s\n", "");
+               // Note: speculative setup continues below but won't be used for enc-dec slots
+           }
+       }

        if (params_base.has_speculative()) {
            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());
@@ -726,7 +744,8 @@ struct server_context_impl {
        slot.mctx = mctx;
        slot.prompt.tokens.has_mtmd = mctx != nullptr;

-       if (model_dft) {
+       // speculative decoding is not supported for encoder-decoder models
+       if (model_dft && !has_encoder) {
            slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);

            // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
@@ -841,7 +860,8 @@ struct server_context_impl {
        bool update_cache = false;

        // find the slot that has at least n% prompt similarity
-       if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+       // skip for encoder-decoder models - slot.prompt.tokens only contains decoder tokens
+       if (ret == nullptr && slot_prompt_similarity != 0.0f && !has_encoder) {
            float sim_best = 0;

            for (server_slot & slot : slots) {
@@ -919,6 +939,11 @@ struct server_context_impl {
        // TODO: mtmd does not support prompt cache
        update_cache = update_cache && (ret->mctx == nullptr);

+       // encoder-decoder models don't support prompt caching:
+       // - encoder outputs depend on the entire input, not just a prefix
+       // - we always clear the decoder KV cache and re-encode
+       update_cache = update_cache && !has_encoder;
+
        if (update_cache) {
            SRV_WRN("%s", "updating prompt cache\n");

@@ -1928,11 +1953,101 @@ struct server_context_impl {
            slot.t_start_process_prompt = ggml_time_us();
            slot.t_start_generation = 0;

-           slot.state = SLOT_STATE_PROCESSING_PROMPT;

            SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n",
                    slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens());

+           // encoder-decoder model handling (e.g., T5, BART, MADLAD)
+           if (has_encoder) {
+               SLT_INF(slot, "encoder-decoder model: encoding %d tokens\n", slot.task->n_tokens());
+
+               // clear the decoder KV cache for this slot - encoder-decoder models
+               // don't support prefix caching, so we always start fresh
+               llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
+               slot.prompt.tokens.clear();
+
+               // empty prompt check
+               if (input_tokens.empty()) {
+                   SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+                   slot.print_timings();
+                   send_final_response(slot);
+                   slot.release();
+                   continue;
+               }
+
+               // get the text tokens for encoding
+               const llama_tokens & text_tokens = input_tokens.get_text_tokens();
+
+               // check for empty text tokens (could happen with multimodal-only input)
+               if (text_tokens.empty()) {
+                   SLT_ERR(slot, "%s", "encoder-decoder models require text tokens\n");
+                   send_error(slot, "encoder-decoder models require text input", ERROR_TYPE_INVALID_REQUEST);
+                   slot.release();
+                   continue;
+               }
+
+               // build encoder batch with all prompt tokens
+               // Note: we need to allocate a proper batch with seq_id support
+               llama_batch batch_enc = llama_batch_init(text_tokens.size(), 0, 1);
+               batch_enc.n_tokens = text_tokens.size();
+
+               for (size_t i = 0; i < text_tokens.size(); i++) {
+                   batch_enc.token[i]     = text_tokens[i];
+                   batch_enc.pos[i]       = i;
+                   batch_enc.n_seq_id[i]  = 1;
+                   batch_enc.seq_id[i][0] = slot.id;
+                   batch_enc.logits[i]    = false;
+               }
+
+               // encode the entire prompt
+               const int ret = llama_encode(ctx, batch_enc);
+
+               // free the encoder batch
+               llama_batch_free(batch_enc);
+
+               if (ret != 0) {
+                   SLT_ERR(slot, "llama_encode() failed with error %d\n", ret);
+                   send_error(slot, "encoder failed", ERROR_TYPE_SERVER);
+                   slot.release();
+                   continue;
+               }
+
+               SLT_INF(slot, "encoder completed, %d tokens encoded\n", slot.task->n_tokens());
+
+               // get decoder start token
+               llama_token decoder_start_token = llama_model_decoder_start_token(model);
+               if (decoder_start_token == LLAMA_TOKEN_NULL) {
+                   decoder_start_token = llama_vocab_bos(vocab);
+               }
+
+               SLT_DBG(slot, "decoder start token: %d '%s'\n",
+                   decoder_start_token, common_token_to_piece(ctx, decoder_start_token).c_str());
+
+               // add decoder start token to the batch
+               common_batch_add(batch, decoder_start_token, 0, { slot.id }, true);
+
+               // update slot state - we've processed all prompt tokens (via encoder)
+               // and the decoder is ready to generate
+               slot.prompt.tokens.clear();
+               slot.prompt.tokens.push_back(decoder_start_token);
+               slot.n_prompt_tokens_processed = slot.task->n_tokens();
+
+               common_sampler_reset(slot.smpl);
+
+               slot.n_decoded = 0;
+               slot.i_batch   = batch.n_tokens - 1;
+
+               slot.state = SLOT_STATE_DONE_PROMPT;
+
+               SLT_INF(slot, "encoder-decoder: prompt encoded, decoder ready%s\n", "");
+
+               if (!slot_batched) {
+                   slot_batched = &slot;
+               }
+
+               continue; // skip normal prompt processing
+           }
+
+           slot.state = SLOT_STATE_PROCESSING_PROMPT;

            // print prompt tokens (for debugging)
            /*if (1) {
                // first 16 tokens (avoid flooding logs)