From d30e59b62a15ef4266a6503e3f4eba770aec001b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 2 Feb 2026 15:18:01 +0200
Subject: [PATCH] llama : add llama_memory_can_rm_suffix()

---
 include/llama.h                  | 3 +++
 src/llama-context.cpp            | 8 ++++++++
 src/llama-kv-cache-iswa.cpp     | 4 ++++
 src/llama-kv-cache-iswa.h       | 3 ++-
 src/llama-kv-cache.cpp          | 4 ++++
 src/llama-kv-cache.h            | 3 ++-
 src/llama-memory-hybrid-iswa.cpp | 7 +++++--
 src/llama-memory-hybrid-iswa.h   | 3 ++-
 src/llama-memory-hybrid.cpp      | 7 +++++--
 src/llama-memory-hybrid.h        | 3 ++-
 src/llama-memory-recurrent.cpp   | 4 ++++
 src/llama-memory-recurrent.h     | 4 +++-
 src/llama-memory.h               | 3 ++-
 tools/server/server-context.cpp  | 4 +++-
 14 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index bf4e28a8be..1866c60e1e 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -760,6 +760,9 @@ extern "C" {
     // Check if the memory supports shifting
     LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
 
+    // Check if the memory supports removing the last tokens in the sequence
+    LLAMA_API bool llama_memory_can_rm_suffix(llama_memory_t mem);
+
     //
     // State / sessions
     //
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 203852d0f1..c635c76877 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -3360,6 +3360,14 @@ bool llama_memory_can_shift(llama_memory_t mem) {
     return mem->get_can_shift();
 }
 
+bool llama_memory_can_rm_suffix(llama_memory_t mem) {
+    if (!mem) {
+        return false;
+    }
+
+    return mem->get_can_rm_suffix();
+}
+
 // llama state API
 
 // deprecated
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index 3a34102a23..3489290736 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -221,6 +221,10 @@ bool llama_kv_cache_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
+bool llama_kv_cache_iswa::get_can_rm_suffix() const {
+    return kv_base->get_can_rm_suffix() && kv_swa->get_can_rm_suffix();
+}
+
 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
diff --git a/src/llama-kv-cache-iswa.h b/src/llama-kv-cache-iswa.h
index 70ab22f0d6..4f15463990 100644
--- a/src/llama-kv-cache-iswa.h
+++ b/src/llama-kv-cache-iswa.h
@@ -43,7 +43,8 @@ public:
     llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
-    bool get_can_shift() const override;
+    bool get_can_shift()     const override;
+    bool get_can_rm_suffix() const override;
 
     void clear(bool data) override;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index c35cd6761b..4c0e650b24 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -977,6 +977,10 @@ bool llama_kv_cache::get_can_shift() const {
     return true;
 }
 
+bool llama_kv_cache::get_can_rm_suffix() const {
+    return true;
+}
+
 uint32_t llama_kv_cache::get_size() const {
     const auto & cells = v_cells[seq_to_stream[0]];
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index e194bf3e26..7f4bcc976b 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -123,7 +123,8 @@ public:
     llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
-    bool get_can_shift() const override;
+    bool get_can_shift()     const override;
+    bool get_can_rm_suffix() const override;
 
     void clear(bool data) override;
diff --git a/src/llama-memory-hybrid-iswa.cpp b/src/llama-memory-hybrid-iswa.cpp
index 411769672a..4cbc19c19d 100644
--- a/src/llama-memory-hybrid-iswa.cpp
+++ b/src/llama-memory-hybrid-iswa.cpp
@@ -126,8 +126,11 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * l
 }
 
 bool llama_memory_hybrid_iswa::get_can_shift() const {
-    // Shifting is trivially supported for recurrent
-    return mem_attn->get_can_shift();
+    return mem_attn->get_can_shift() && mem_recr->get_can_shift();
+}
+
+bool llama_memory_hybrid_iswa::get_can_rm_suffix() const {
+    return mem_attn->get_can_rm_suffix() && mem_recr->get_can_rm_suffix();
 }
 
 void llama_memory_hybrid_iswa::clear(bool data) {
diff --git a/src/llama-memory-hybrid-iswa.h b/src/llama-memory-hybrid-iswa.h
index 807c8aac96..3f91f1408f 100644
--- a/src/llama-memory-hybrid-iswa.h
+++ b/src/llama-memory-hybrid-iswa.h
@@ -55,7 +55,8 @@ public:
     llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
-    bool get_can_shift() const override;
+    bool get_can_shift()     const override;
+    bool get_can_rm_suffix() const override;
 
     void clear(bool data) override;
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index a1b45e4a3c..c15cfe1c15 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -120,8 +120,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx,
 }
 
 bool llama_memory_hybrid::get_can_shift() const {
-    // Shifting is trivially supported for recurrent
-    return mem_attn->get_can_shift();
+    return mem_attn->get_can_shift() && mem_recr->get_can_shift();
+}
+
+bool llama_memory_hybrid::get_can_rm_suffix() const {
+    return mem_attn->get_can_rm_suffix() && mem_recr->get_can_rm_suffix();
 }
 
 void llama_memory_hybrid::clear(bool data) {
diff --git a/src/llama-memory-hybrid.h b/src/llama-memory-hybrid.h
index 558cafdf98..08d1a79f2b 100644
--- a/src/llama-memory-hybrid.h
+++ b/src/llama-memory-hybrid.h
@@ -55,7 +55,8 @@ public:
     llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
-    bool get_can_shift() const override;
+    bool get_can_shift()     const override;
+    bool get_can_rm_suffix() const override;
 
     void clear(bool data) override;
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index f0038036dc..3c0f6f8587 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -665,6 +665,10 @@ bool llama_memory_recurrent::get_can_shift() const {
     return true;
 }
 
+bool llama_memory_recurrent::get_can_rm_suffix() const {
+    return false;
+}
+
 size_t llama_memory_recurrent::total_size() const {
     size_t size = 0;
     for (const auto & [_, buf] : ctxs_bufs) {
diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h
index 47f01d7391..3e67b4cdce 100644
--- a/src/llama-memory-recurrent.h
+++ b/src/llama-memory-recurrent.h
@@ -58,7 +58,9 @@ public:
     // find a contiguous slot of memory cells and emplace the ubatch there
     bool find_slot(const llama_ubatch & ubatch);
 
-    bool get_can_shift() const override;
+    bool get_can_shift()     const override;
+    bool get_can_rm_suffix() const override;
+
     // state write/load
diff --git a/src/llama-memory.h b/src/llama-memory.h
index 4a157b91fd..07d7341479 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -91,7 +91,8 @@ struct llama_memory_i {
     virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
 
     // getters
-    virtual bool get_can_shift() const = 0;
+    virtual bool get_can_shift()     const = 0;
+    virtual bool get_can_rm_suffix() const = 0;
 
     //
     // ops
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 7f9c3c566b..42b21248a4 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -752,7 +752,7 @@ private:
         slot.prompt.tokens.has_mtmd = mctx != nullptr;
 
         // try speculative decoding
-        {
+        if (llama_memory_can_rm_suffix(llama_get_memory(ctx))) {
            slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
            if (slot.spec) {
                if (mctx) {
@@ -763,6 +763,8 @@
            } else {
                SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
            }
+        } else {
+            SLT_WRN(slot, "%s", "speculative decoding not supported by this context (no memory_rm_suffix support)\n");
         }
 
         SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
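
A minimal caller-side sketch of how the new check is meant to compose with the
existing llama_memory_seq_rm() API, in the spirit of the speculative-decoding
rollback that motivates the server change above. The rollback_last_tokens()
helper and the n_past/n_drop parameters are illustrative assumptions, not part
of this patch:

    #include "llama.h"

    // Illustrative helper: drop the last n_drop tokens of sequence 0, but
    // only when the memory reports support for suffix removal.
    static bool rollback_last_tokens(llama_context * ctx, llama_pos n_past, llama_pos n_drop) {
        llama_memory_t mem = llama_get_memory(ctx);

        if (!llama_memory_can_rm_suffix(mem)) {
            // e.g. recurrent memory: the state cannot be rewound to an
            // earlier position, so the caller must re-process the sequence
            return false;
        }

        // remove positions [n_past - n_drop, end) from sequence 0
        return llama_memory_seq_rm(mem, 0, n_past - n_drop, -1);
    }

With this patch, llama_kv_cache always reports support, llama_memory_recurrent
never does, and the iSWA/hybrid wrappers report support only when all of their
child memories do - which is exactly the condition the server now uses to gate
speculative decoding.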