Feature: implement infinite conversation via context shifting

Han Yin 2025-04-08 13:32:15 -07:00
parent 4e515727b4
commit ec502cfde9
1 changed file with 64 additions and 35 deletions


@@ -311,8 +311,34 @@ constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static llama_pos current_position;
static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;
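// Long-term (per-conversation) states: the chat history and token positions above persist
// across user turns and are cleared by reset_long_term_states() when a new system prompt
// starts a conversation (see processSystemPrompt below).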
static void reset_long_term_states(const bool clear_kv_cache = true) {
chat_msgs.clear();
system_prompt_position = 0;
current_position = 0;
if (clear_kv_cache)
llama_memory_clear(llama_get_memory(g_context), false);
}
/**
* Context shifting by discarding the older half of the tokens appended after the system prompt:
* - keep the first [system_prompt_position] tokens from the original prompt (the system prompt)
* - take half of the last (current_position - system_prompt_position) tokens
* - recompute the logits in batches
*/
static void shift_context() {
const int n_discard = (current_position - system_prompt_position) / 2;
LOGi("Discarding %d tokens", n_discard);
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
current_position -= n_discard;
LOGi("Context shifting done! Current position: %d", current_position);
}
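// Worked example (hypothetical numbers, not from this change): with system_prompt_position = 50
// and current_position = 450, n_discard = (450 - 50) / 2 = 200. llama_memory_seq_rm drops the
// KV-cache entries at positions [50, 250), llama_memory_seq_add shifts the entries at [250, 450)
// back by 200 to occupy [50, 250), and current_position becomes 250, preserving the system prompt
// plus the most recent half of the conversation.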
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
@@ -334,20 +360,26 @@ static int decode_tokens_in_batches(
// Process tokens in batches using the global batch
LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(*batch);
LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
// Shift context if current batch cannot fit into the context
if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("Current batch won't fit into context! Shifting...");
shift_context();
}
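// (Assumption: CONTEXT_SIZE, OVERFLOW_HEADROOM and BATCH_SIZE are constants defined elsewhere
// in this file; the headroom appears to reserve a safety margin so the batch about to be
// decoded still fits after the shift.)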
// Add tokens to the batch with proper positions
for (int j = 0; j < cur_batch_size; j++) {
llama_token token_id = tokens[i + j];
llama_pos position = start_pos + i + j;
bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
const llama_token token_id = tokens[i + j];
const llama_pos position = start_pos + i + j;
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
common_batch_add(*batch, token_id, position, {0}, want_logit);
}
// Decode this batch
int decode_result = llama_decode(context, *batch);
const int decode_result = llama_decode(context, *batch);
if (decode_result) {
LOGe("llama_decode failed w/ %d", decode_result);
return 1;
@@ -359,10 +391,16 @@ static int decode_tokens_in_batches(
/**
* Prediction loop's short-term states
*/
static llama_pos token_predict_budget;
static llama_pos stop_completion_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss; // For storing current assistant message
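// Short-term (per-turn) states: cleared by reset_short_term_states() at the start of every
// system or user prompt, in contrast to the long-term conversation states above.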
static void reset_short_term_states() {
stop_completion_position = 0;
cached_token_chars.clear();
assistant_ss.str("");
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -370,14 +408,9 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
jobject /*unused*/,
jstring jsystem_prompt
) {
// Reset long-term states and reset KV cache
current_position = 0;
llama_memory_clear(llama_get_memory(g_context), false);
// Reset short-term states
token_predict_budget = 0;
cached_token_chars.clear();
assistant_ss.str("");
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
@@ -413,11 +446,10 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
}
// Update position
current_position = (int) system_tokens.size();
system_prompt_position = current_position = (int) system_tokens.size();
return 0;
}
// TODO-hyin: support KV cache backtracking
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
@@ -427,9 +459,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
jint n_predict
) {
// Reset short-term states
token_predict_budget = 0;
cached_token_chars.clear();
assistant_ss.str("");
reset_short_term_states();
// Obtain and tokenize user prompt
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
@@ -450,22 +480,14 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
}
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int user_prompt_size = (int) user_tokens.size();
const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) user_tokens.size() > max_batch_size) {
const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
if (user_prompt_size > max_batch_size) {
const int skipped_tokens = user_prompt_size - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
}
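// Example (hypothetical constants): with CONTEXT_SIZE = 4096 and OVERFLOW_HEADROOM = 64,
// max_batch_size is 4032, so a 5000-token user prompt keeps its first 4032 tokens and the
// remaining 968 are skipped.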
// TODO-hyin: implement context shifting
// Check if context space is enough for desired tokens
int desired_budget = current_position + (int) user_tokens.size() + n_predict;
if (desired_budget > max_batch_size) {
LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
return 1;
}
token_predict_budget = desired_budget;
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
LOGe("llama_decode() failed!");
@@ -473,7 +495,8 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
}
// Update position
current_position += (int) user_tokens.size();
current_position += user_prompt_size;
stop_completion_position = current_position + n_predict;  // stop after at most n_predict newly generated tokens
return 0;
}
@@ -517,9 +540,15 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop(
JNIEnv *env,
jobject /*unused*/
) {
// Stop if running out of token budget
if (current_position >= token_predict_budget) {
LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
// Infinite text generation via context shifting
if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("Context full! Shifting...");
shift_context();
}
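// After a shift, current_position has moved back by n_discard, so generation continues from the
// compacted context; repeated shifts are what make the conversation effectively unbounded.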
// Stop if reaching the marked position
if (current_position >= stop_completion_position) {
LOGw("STOP: hitting stop position: %d", stop_completion_position);
return nullptr;
}