memory : remove KV cache size padding (#16812)

* memory : remove KV cache size padding

* cont : restore padding for n_kv tensor shape

* server : use slot context size instead of training context size

* server : simplify context limit logic
Georgi Gerganov, 2025-10-28 20:19:44 +02:00 (committed by GitHub)
parent a8ca18b4b8
commit 85a7d8677b
6 changed files with 14 additions and 54 deletions


@@ -961,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const {
 uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
     uint32_t result = 0;
 
+    // pad the n_kv value so that the graph remains constant across batches and can be reused
+    // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220)
+    const uint32_t n_pad_cur = std::max(n_pad, 256u);
+
     for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
         const auto & cells = v_cells[sinfo.strm[s]];
 
-        result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
+        result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result);
     }
 
     return result;
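
For reference, a minimal standalone sketch of how the padded n_kv value works out under the new code path. The cache sizes are hypothetical and pad_to is a local stand-in for GGML_PAD (assumed here to round its first argument up to a multiple of the second); the 256 floor is the value hard-coded above:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// local stand-in for GGML_PAD: round x up to a multiple of n
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const uint32_t n_pad       = 1;    // the value create_memory() now passes (see below)
    const uint32_t cache_size  = 8192; // hypothetical cells.size()
    const uint32_t used_max_p1 = 1000; // hypothetical index of the highest used cell + 1

    // mirror of the logic in get_n_kv(): pad to at least 256, but never beyond the cache size
    const uint32_t n_pad_cur = std::max(n_pad, 256u);
    const uint32_t n_kv      = std::min(cache_size, std::max(n_pad_cur, pad_to(used_max_p1, n_pad_cur)));

    printf("n_kv = %u\n", n_kv); // prints 1024
    return 0;
}

Because the per-slot cache itself is no longer padded, only this per-graph view is rounded up, which keeps the graph shape stable across batches without over-allocating KV memory.
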
@@ -2014,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
-
-uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
-    // the FA kernels require padding to avoid extra runtime boundary checks
-    return cparams.flash_attn ? 256u : 32u;
-}
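
The removed helper is what previously drove up-front padding of the whole cache (256 cells with flash attention, 32 otherwise). A rough before/after sketch of the allocation size under hypothetical settings; this illustrates the described behavior rather than reproducing code from the change:

#include <cstdint>
#include <cstdio>

static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const uint32_t n_ctx_requested = 1000; // hypothetical user-requested context size
    const bool     flash_attn      = true;

    // before: the requested context was rounded up to the attention padding (256 with FA, 32 otherwise)
    const uint32_t padding_old    = flash_attn ? 256u : 32u;
    const uint32_t cache_size_old = pad_to(n_ctx_requested, padding_old);

    // after: the cache keeps the exact requested size; only the per-graph n_kv view is padded
    const uint32_t cache_size_new = n_ctx_requested;

    printf("old = %u cells, new = %u cells\n", cache_size_old, cache_size_new); // old = 1024, new = 1000
    return 0;
}
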


@@ -19,8 +19,6 @@ struct llama_context;
 
 class llama_kv_cache : public llama_memory_i {
 public:
-    static uint32_t get_padding(const llama_cparams & cparams);
-
     struct stream_copy_info {
         bool empty() const {
             assert(ssrc.size() == sdst.size());


@@ -19641,7 +19641,7 @@ struct llm_build_apertus : public llm_graph_context {
     }
 };
 
-llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
+llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const {
     llama_memory_i * res;
 
     switch (arch) {
@@ -19692,17 +19692,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 };
             }
 
-            const auto padding = llama_kv_cache::get_padding(cparams);
-
-            cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
             res = new llama_memory_hybrid(
                 /* model */ *this,
                 /* attn_type_k */ params.type_k,
                 /* attn_type_v */ params.type_v,
                 /* attn_v_trans */ !cparams.flash_attn,
                 /* attn_kv_size */ cparams.n_ctx,
-                /* attn_n_pad */ padding,
+                /* attn_n_pad */ 1,
                 /* attn_n_swa */ hparams.n_swa,
                 /* attn_swa_type */ hparams.swa_type,
                 /* recurrent_type_k */ GGML_TYPE_F32,
@@ -19714,23 +19710,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 /* filter_attn */ std::move(filter_attn),
                 /* filter_recr */ std::move(filter_recr));
         } else {
-            const auto padding = llama_kv_cache::get_padding(cparams);
-
             uint32_t n_ctx_per_stream = cparams.n_ctx;
 
             if (!cparams.kv_unified) {
                 n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
-                n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
-
-                cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
-            } else {
-                n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
-
-                cparams.n_ctx = n_ctx_per_stream;
             }
 
-            LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
             llama_memory_i::layer_reuse_cb reuse = nullptr;
 
             if (arch == LLM_ARCH_GEMMA3N) {
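
With the padding removed, the per-stream context is now just a ceiling division of the requested total and cparams.n_ctx is no longer rewritten, which is what allows create_memory to take cparams by const reference. A small worked example with hypothetical numbers:

#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_ctx     = 8192; // hypothetical total context
    const uint32_t n_seq_max = 3;    // hypothetical number of parallel sequences

    // ceiling division, as in create_memory() when the KV cache is not unified
    const uint32_t n_ctx_per_stream = (n_ctx + n_seq_max - 1) / n_seq_max;

    // previously this value was additionally rounded up to a multiple of the attention
    // padding and written back into cparams.n_ctx; now it is used as-is
    printf("n_ctx_per_stream = %u\n", n_ctx_per_stream); // prints 2731
    return 0;
}
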
@@ -19757,7 +19742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
                     cparams.n_ubatch,
-                    padding,
+                    1,
                     nullptr,
                     reuse);
             } else {
@@ -19772,7 +19757,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     cparams.kv_unified,
                     n_ctx_per_stream,
                     cparams.n_seq_max,
-                    padding,
+                    1,
                     hparams.n_swa,
                     hparams.swa_type,
                     nullptr,


@@ -500,9 +500,8 @@ struct llama_model {
 
     ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
-    // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface
-    llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
+    llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
 
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;


@@ -2866,10 +2866,12 @@ struct server_context {
         // if context shifting is disabled, make sure that we don't run out of context
        if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+            slot.truncated      = true;
             slot.stop           = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
         }
 
         // check the limits
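
After the deletions below, the guard above becomes the single place where a slot with context shifting disabled stops at its context limit. A condensed sketch of that decision, using an illustrative struct rather than the server's actual slot type:

#include <cstdio>

// illustrative stand-in for the server slot fields referenced above (not the real type)
struct slot_state {
    int  n_past         = 0;
    int  n_ctx          = 0;
    bool truncated      = false;
    bool has_next_token = true;
};

int main() {
    slot_state slot;
    slot.n_past = 255; // hypothetical: one token away from the slot's context size
    slot.n_ctx  = 256;

    const bool ctx_shift = false; // context shifting disabled

    // with context shifting disabled, stop one token before the slot context is exhausted
    if (!ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
        slot.truncated      = true;
        slot.has_next_token = false;
        printf("stopped due to running out of context capacity: n_past = %d, n_ctx = %d\n",
               slot.n_past, slot.n_ctx);
    }
    return 0;
}
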
@@ -2929,16 +2931,6 @@ struct server_context {
             }
         }
 
-        // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_past >= slot.n_ctx) {
-            slot.truncated      = true;
-            slot.stop           = STOP_TYPE_LIMIT;
-            slot.has_next_token = false;
-
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx);
-        }
-
         if (llama_vocab_is_eog(vocab, result.tok)) {
             slot.stop           = STOP_TYPE_EOS;
             slot.has_next_token = false;
@@ -2946,19 +2938,6 @@ struct server_context {
             SLT_DBG(slot, "%s", "stopped by EOS\n");
         }
 
-        const auto n_ctx_train = llama_model_n_ctx_train(model);
-
-        if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
-            slot.truncated      = true;
-            slot.stop           = STOP_TYPE_LIMIT;
-            slot.has_next_token = false; // stop prediction
-
-            SLT_WRN(slot,
-                    "n_predict (%d) is set for infinite generation. "
-                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
-                    slot.task->params.n_predict, n_ctx_train);
-        }
-
         SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
 
         return slot.has_next_token; // continue


@@ -45,7 +45,7 @@ def test_ctx_shift_enabled():
 @pytest.mark.parametrize("n_predict,n_token_output,truncated", [
     (64, 64, False),
-    (-1, 120, True),
+    (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total
 ])
 def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
     global server
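
The updated expectation follows directly from the budget stated in the inline comment; a trivial check of that arithmetic (the 256-token slot context is taken from that comment, not from the test setup shown in this hunk):

#include <cassert>
#include <cstdio>

int main() {
    const int n_ctx_slot  = 256; // per-slot context implied by the test comment
    const int n_prompt    = 8;   // prompt tokens, per the comment
    const int n_generated = n_ctx_slot - n_prompt;

    assert(n_generated == 248); // matches the updated n_token_output expectation
    printf("prompt %d + generated %d = %d tokens\n", n_prompt, n_generated, n_ctx_slot);
    return 0;
}
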