diff --git a/examples/llama.android/lib/src/main/cpp/ai_chat.cpp b/examples/llama.android/lib/src/main/cpp/ai_chat.cpp index 9e460ac198..a37b6b6343 100644 --- a/examples/llama.android/lib/src/main/cpp/ai_chat.cpp +++ b/examples/llama.android/lib/src/main/cpp/ai_chat.cpp @@ -269,23 +269,6 @@ static void reset_long_term_states(const bool clear_kv_cache = true) { llama_memory_clear(llama_get_memory(g_context), false); } -/** - * TODO-hyin: implement sliding-window version as a better alternative - * - * Context shifting by discarding the older half of the tokens appended after system prompt: - * - take the [system_prompt_position] first tokens from the original prompt - * - take half of the last (system_prompt_position - system_prompt_position) tokens - * - recompute the logits in batches - */ -static void shift_context() { - const int n_discard = (current_position - system_prompt_position) / 2; - LOGi("%s: Discarding %d tokens", __func__, n_discard); - llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard); - llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard); - current_position -= n_discard; - LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); -} - static std::string chat_add_and_format(const std::string &role, const std::string &content) { common_chat_msg new_msg; new_msg.role = role; @@ -313,6 +296,32 @@ static void reset_short_term_states() { assistant_ss.str(""); } +/** + * TODO-hyin: implement sliding-window version as a better alternative + * + * Context shifting by discarding the older half of the tokens appended after system prompt: + * - take the [keep_recent] first tokens from the original prompt + * - take half of the last (current_position - keep_first) tokens + * - recompute the logits in batches + * + * attention_sink: keep the first 4 tokens to maintain attention. + */ +static void shift_context() { + const int attention_sink = 4; + const int keep_first = std::max(system_prompt_position, attention_sink); + const int n_discard = (current_position - keep_first) / 2; + if (n_discard <= 0) { + LOGi("%s: n_discard <= 0", __func__); + return; + } + LOGi("%s: Discarding %d tokens", __func__, n_discard); + llama_memory_seq_rm(llama_get_memory(g_context), 0, keep_first, keep_first + n_discard); + llama_memory_seq_add(llama_get_memory(g_context), 0, keep_first + n_discard, -1, -n_discard); + current_position -= n_discard; + stop_generation_position -= n_discard; + LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); +} + static int decode_tokens_in_batches( llama_context *context, llama_batch &batch,