Merge 2d55904a15 into 18ddaea2ae
commit d8f720b472
@@ -269,23 +269,6 @@ static void reset_long_term_states(const bool clear_kv_cache = true) {
     llama_memory_clear(llama_get_memory(g_context), false);
 }
 
-/**
- * TODO-hyin: implement sliding-window version as a better alternative
- *
- * Context shifting by discarding the older half of the tokens appended after system prompt:
- * - take the [system_prompt_position] first tokens from the original prompt
- * - take half of the last (system_prompt_position - system_prompt_position) tokens
- * - recompute the logits in batches
- */
-static void shift_context() {
-    const int n_discard = (current_position - system_prompt_position) / 2;
-    LOGi("%s: Discarding %d tokens", __func__, n_discard);
-    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
-    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
-    current_position -= n_discard;
-    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
-}
-
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
     new_msg.role = role;
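The removed version is easiest to follow with concrete numbers. Below is a minimal standalone sketch of its discard arithmetic; the values of system_prompt_position and current_position are hypothetical, and no llama.cpp calls are made:

// Standalone sketch of the removed shift's arithmetic (hypothetical values;
// no llama.cpp calls). The older half of the post-prompt tokens is dropped
// and the survivors slide left so cache positions stay contiguous.
#include <cstdio>

int main() {
    int system_prompt_position = 32;  // assumed: system prompt occupies cache cells [0, 32)
    int current_position       = 96;  // assumed: 96 tokens currently in the cache

    const int n_discard = (current_position - system_prompt_position) / 2;  // 32

    // seq_rm drops cells [32, 64); seq_add then shifts [64, 96) left by 32.
    for (int pos = system_prompt_position + n_discard; pos < current_position; ++pos) {
        std::printf("token at position %d moves to %d\n", pos, pos - n_discard);
    }

    current_position -= n_discard;
    std::printf("current_position is now %d\n", current_position);  // 64
    return 0;
}

With these numbers, cells [32, 64) are evicted and [64, 96) are renumbered to [32, 64), leaving 64 tokens in the cache. Note that when system_prompt_position is small, nothing in this version protects the very first tokens, which is what the rewrite below addresses.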
@@ -313,6 +296,32 @@ static void reset_short_term_states() {
     assistant_ss.str("");
 }
 
+/**
+ * TODO-hyin: implement sliding-window version as a better alternative
+ *
+ * Context shifting by discarding the older half of the tokens appended after system prompt:
+ * - take the [keep_first] first tokens from the original prompt
+ * - take half of the last (current_position - keep_first) tokens
+ * - recompute the logits in batches
+ *
+ * attention_sink: keep the first 4 tokens to maintain attention.
+ */
+static void shift_context() {
+    const int attention_sink = 4;
+    const int keep_first = std::max(system_prompt_position, attention_sink);
+    const int n_discard = (current_position - keep_first) / 2;
+    if (n_discard <= 0) {
+        LOGi("%s: n_discard <= 0", __func__);
+        return;
+    }
+    LOGi("%s: Discarding %d tokens", __func__, n_discard);
+    llama_memory_seq_rm(llama_get_memory(g_context), 0, keep_first, keep_first + n_discard);
+    llama_memory_seq_add(llama_get_memory(g_context), 0, keep_first + n_discard, -1, -n_discard);
+    current_position -= n_discard;
+    stop_generation_position -= n_discard;
+    LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
+}
+
 static int decode_tokens_in_batches(
     llama_context *context,
     llama_batch &batch,
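The rewritten version adds two safeguards: keep_first floors the protected prefix at the first 4 attention-sink tokens even when the system prompt is shorter (or absent), and the n_discard guard bails out on a nearly empty context. It also passes -1 as p1 to llama_memory_seq_add, which per the convention documented in llama.h means "through the end of the sequence". A standalone sketch with hypothetical values showing the guard taking effect:

// Standalone sketch of the new guard logic (hypothetical values; no llama.cpp calls).
#include <algorithm>
#include <cstdio>

int main() {
    const int attention_sink         = 4;
    const int system_prompt_position = 0;  // assumed: no system prompt this run
    const int current_position       = 5;  // assumed: only five tokens decoded so far

    const int keep_first = std::max(system_prompt_position, attention_sink);  // 4
    const int n_discard  = (current_position - keep_first) / 2;               // 0

    if (n_discard <= 0) {
        // The removed version would have computed n_discard = 2 here and evicted
        // cache cells [0, 2), i.e. two of the four sink tokens; the guard prevents that.
        std::printf("nothing to discard (n_discard = %d)\n", n_discard);
        return 0;
    }
    std::printf("would discard cells [%d, %d)\n", keep_first, keep_first + n_discard);
    return 0;
}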