Karlon 2026-01-02 15:16:00 -08:00 committed by GitHub
commit d8f720b472
1 changed file with 26 additions and 17 deletions


@@ -269,23 +269,6 @@ static void reset_long_term_states(const bool clear_kv_cache = true) {
     llama_memory_clear(llama_get_memory(g_context), false);
 }
-
-/**
- * TODO-hyin: implement sliding-window version as a better alternative
- *
- * Context shifting by discarding the older half of the tokens appended after the system prompt:
- * - take the first [system_prompt_position] tokens from the original prompt
- * - take half of the last (current_position - system_prompt_position) tokens
- * - recompute the logits in batches
- */
-static void shift_context() {
-    const int n_discard = (current_position - system_prompt_position) / 2;
-    LOGi("%s: Discarding %d tokens", __func__, n_discard);
-    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
-    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
-    current_position -= n_discard;
-    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
-}
 
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
     new_msg.role = role;
@@ -313,6 +296,32 @@ static void reset_short_term_states() {
     assistant_ss.str("");
 }
 
+/**
+ * TODO-hyin: implement sliding-window version as a better alternative
+ *
+ * Context shifting by discarding the older half of the tokens appended after the system prompt:
+ * - take the first [keep_first] tokens from the original prompt
+ * - take half of the last (current_position - keep_first) tokens
+ * - recompute the logits in batches
+ *
+ * attention_sink: keep the first 4 tokens to maintain attention.
+ */
+static void shift_context() {
+    const int attention_sink = 4;
+    const int keep_first = std::max(system_prompt_position, attention_sink);
+    const int n_discard = (current_position - keep_first) / 2;
+    if (n_discard <= 0) {
+        LOGi("%s: n_discard <= 0", __func__);
+        return;
+    }
+    LOGi("%s: Discarding %d tokens", __func__, n_discard);
+    llama_memory_seq_rm(llama_get_memory(g_context), 0, keep_first, keep_first + n_discard);
+    llama_memory_seq_add(llama_get_memory(g_context), 0, keep_first + n_discard, -1, -n_discard);
+    current_position -= n_discard;
+    stop_generation_position -= n_discard;
+    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
+}
+
 static int decode_tokens_in_batches(
     llama_context *context,
     llama_batch &batch,
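
To make the position arithmetic above concrete, here is a minimal, self-contained C++ sketch (illustrative only, not part of the commit) that simulates what the new shift_context() does to cached token positions. The concrete values of system_prompt_position and current_position are assumptions, and std::vector::erase plus a subtraction loop stand in for llama_memory_seq_rm() and llama_memory_seq_add():

// Illustrative sketch only: simulates the position arithmetic of the new
// shift_context(). The KV cache is modeled as a vector of token positions;
// erase() stands in for llama_memory_seq_rm() and the subtraction loop for
// llama_memory_seq_add(). All concrete values are assumptions.
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int system_prompt_position = 2;  // assumed: system prompt shorter than the sink
    const int attention_sink = 4;          // first tokens kept to preserve attention
    int current_position = 20;             // assumed number of cached tokens

    // Cache holds contiguous positions 0..current_position-1.
    std::vector<int> cache(current_position);
    for (int i = 0; i < current_position; ++i) cache[i] = i;

    const int keep_first = std::max(system_prompt_position, attention_sink);  // 4
    const int n_discard  = (current_position - keep_first) / 2;               // 8

    // llama_memory_seq_rm(.., keep_first, keep_first + n_discard):
    // drop positions [4, 12).
    cache.erase(cache.begin() + keep_first, cache.begin() + keep_first + n_discard);

    // llama_memory_seq_add(.., keep_first + n_discard, -1, -n_discard):
    // shift every surviving position after the hole left by n_discard.
    for (size_t i = keep_first; i < cache.size(); ++i) cache[i] -= n_discard;

    current_position -= n_discard;

    // Positions are contiguous again (0..current_position-1), so decoding can
    // continue without re-encoding the kept tokens.
    for (int i = 0; i < current_position; ++i) assert(cache[i] == i);
    std::printf("current_position=%d, cache size=%zu\n", current_position, cache.size());
}

Note the end bound of -1 now passed to llama_memory_seq_add(): in llama.cpp a negative p1 means "through the end of the sequence", where the old code stopped at current_position. Keeping the first max(system_prompt_position, attention_sink) tokens follows the attention-sink observation that the earliest tokens attract a disproportionate share of attention, so evicting them tends to degrade generation.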