llama : add llama_memory_can_rm_suffix()
This commit is contained in:
parent
6fdddb4987
commit
d30e59b62a
|
|
@ -760,6 +760,9 @@ extern "C" {
|
||||||
// Check if the memory supports shifting
|
// Check if the memory supports shifting
|
||||||
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
|
||||||
|
|
||||||
|
// Check if the memory supports removing the last tokens in the sequence
|
||||||
|
LLAMA_API bool llama_memory_can_rm_suffix(llama_memory_t mem);
|
||||||
|
|
||||||
//
|
//
|
||||||
// State / sessions
|
// State / sessions
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -3360,6 +3360,14 @@ bool llama_memory_can_shift(llama_memory_t mem) {
|
||||||
return mem->get_can_shift();
|
return mem->get_can_shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_memory_can_rm_suffix(llama_memory_t mem) {
|
||||||
|
if (!mem) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return mem->get_can_rm_suffix();
|
||||||
|
}
|
||||||
|
|
||||||
// llama state API
|
// llama state API
|
||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
|
|
|
||||||
|
|
@ -221,6 +221,10 @@ bool llama_kv_cache_iswa::get_can_shift() const {
|
||||||
return kv_base->get_size() == kv_swa->get_size();
|
return kv_base->get_size() == kv_swa->get_size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_iswa::get_can_rm_suffix() const {
|
||||||
|
return kv_base->get_can_rm_suffix() && kv_swa->get_can_rm_suffix();
|
||||||
|
}
|
||||||
|
|
||||||
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
|
||||||
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
|
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
|
||||||
kv_base->state_write(io, seq_id, flags);
|
kv_base->state_write(io, seq_id, flags);
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,8 @@ public:
|
||||||
|
|
||||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
bool get_can_rm_suffix() const override;
|
||||||
|
|
||||||
void clear(bool data) override;
|
void clear(bool data) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -977,6 +977,10 @@ bool llama_kv_cache::get_can_shift() const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache::get_can_rm_suffix() const {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache::get_size() const {
|
uint32_t llama_kv_cache::get_size() const {
|
||||||
const auto & cells = v_cells[seq_to_stream[0]];
|
const auto & cells = v_cells[seq_to_stream[0]];
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,8 @@ public:
|
||||||
|
|
||||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
bool get_can_rm_suffix() const override;
|
||||||
|
|
||||||
void clear(bool data) override;
|
void clear(bool data) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -126,8 +126,11 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * l
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_hybrid_iswa::get_can_shift() const {
|
bool llama_memory_hybrid_iswa::get_can_shift() const {
|
||||||
// Shifting is trivially supported for recurrent
|
return mem_attn->get_can_shift() && mem_recr->get_can_shift();
|
||||||
return mem_attn->get_can_shift();
|
}
|
||||||
|
|
||||||
|
bool llama_memory_hybrid_iswa::get_can_rm_suffix() const {
|
||||||
|
return mem_attn->get_can_rm_suffix() && mem_recr->get_can_rm_suffix();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_memory_hybrid_iswa::clear(bool data) {
|
void llama_memory_hybrid_iswa::clear(bool data) {
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,8 @@ public:
|
||||||
|
|
||||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
bool get_can_rm_suffix() const override;
|
||||||
|
|
||||||
void clear(bool data) override;
|
void clear(bool data) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -120,8 +120,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_memory_hybrid::get_can_shift() const {
|
bool llama_memory_hybrid::get_can_shift() const {
|
||||||
// Shifting is trivially supported for recurrent
|
return mem_attn->get_can_shift() && mem_recr->get_can_shift();
|
||||||
return mem_attn->get_can_shift();
|
}
|
||||||
|
|
||||||
|
bool llama_memory_hybrid::get_can_rm_suffix() const {
|
||||||
|
return mem_attn->get_can_rm_suffix() && mem_recr->get_can_rm_suffix();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_memory_hybrid::clear(bool data) {
|
void llama_memory_hybrid::clear(bool data) {
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,8 @@ public:
|
||||||
|
|
||||||
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
bool get_can_rm_suffix() const override;
|
||||||
|
|
||||||
void clear(bool data) override;
|
void clear(bool data) override;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -665,6 +665,10 @@ bool llama_memory_recurrent::get_can_shift() const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_memory_recurrent::get_can_rm_suffix() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
size_t llama_memory_recurrent::total_size() const {
|
size_t llama_memory_recurrent::total_size() const {
|
||||||
size_t size = 0;
|
size_t size = 0;
|
||||||
for (const auto & [_, buf] : ctxs_bufs) {
|
for (const auto & [_, buf] : ctxs_bufs) {
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,9 @@ public:
|
||||||
// find a contiguous slot of memory cells and emplace the ubatch there
|
// find a contiguous slot of memory cells and emplace the ubatch there
|
||||||
bool find_slot(const llama_ubatch & ubatch);
|
bool find_slot(const llama_ubatch & ubatch);
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
bool get_can_rm_suffix() const override;
|
||||||
|
|
||||||
|
|
||||||
// state write/load
|
// state write/load
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,8 @@ struct llama_memory_i {
|
||||||
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||||
|
|
||||||
// getters
|
// getters
|
||||||
virtual bool get_can_shift() const = 0;
|
virtual bool get_can_shift() const = 0;
|
||||||
|
virtual bool get_can_rm_suffix() const = 0;
|
||||||
|
|
||||||
//
|
//
|
||||||
// ops
|
// ops
|
||||||
|
|
|
||||||
|
|
@ -752,7 +752,7 @@ private:
|
||||||
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
slot.prompt.tokens.has_mtmd = mctx != nullptr;
|
||||||
|
|
||||||
// try speculative decoding
|
// try speculative decoding
|
||||||
{
|
if (llama_memory_can_rm_suffix(llama_get_memory(ctx))) {
|
||||||
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
|
||||||
if (slot.spec) {
|
if (slot.spec) {
|
||||||
if (mctx) {
|
if (mctx) {
|
||||||
|
|
@ -763,6 +763,8 @@ private:
|
||||||
} else {
|
} else {
|
||||||
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
SLT_INF(slot, "%s", "speculative decoding context not initialized\n");
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
SLT_WRN(slot, "%s", "speculative decoding not supported by this context (no memory_rm_suffix support)\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue