diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 344308763f..019caacc4a 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -311,8 +311,37 @@
 constexpr const char *ROLE_SYSTEM = "system";
 constexpr const char *ROLE_USER = "user";
 constexpr const char *ROLE_ASSISTANT = "assistant";
-static llama_pos current_position;
 static std::vector<common_chat_msg> chat_msgs;
+static llama_pos system_prompt_position;
+static llama_pos current_position;
+
+static void reset_long_term_states(const bool clear_kv_cache = true) {
+    chat_msgs.clear();
+    system_prompt_position = 0;
+    current_position = 0;
+
+    if (clear_kv_cache)
+        llama_memory_clear(llama_get_memory(g_context), false);
+}
+
+/**
+ * Context shifting: discard the older half of the tokens decoded after the system prompt:
+ * - keep the first [system_prompt_position] tokens (the system prompt)
+ * - discard half of the (n_past - system_prompt_position) tokens that follow it
+ * - slide the remaining cache entries back by the number of discarded tokens
+ *
+ * [n_past] is the position right after the last token in the KV cache.
+ * Returns the number of discarded tokens, so callers can adjust their positions.
+ */
+static int shift_context(const llama_pos n_past) {
+    const int n_discard = (n_past - system_prompt_position) / 2;
+    LOGi("Discarding %d tokens", n_discard);
+
+    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
+    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, n_past, -n_discard);
+    LOGi("Context shifting done! Discarded %d tokens", n_discard);
+    return n_discard;
+}
 
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
@@ -334,20 +363,30 @@ static int decode_tokens_in_batches(
     // Process tokens in batches using the global batch
     LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
+    llama_pos pos = start_pos;
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
-        int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
+        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
         LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
 
+        // Shift context if the current batch cannot fit into the context
+        if (pos + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+            LOGw("Current batch won't fit into context! Shifting...");
+            const int n_discard = shift_context(pos);
+            pos              -= n_discard;
+            current_position -= n_discard; // keep the callers' position bookkeeping consistent
+        }
+
         // Add tokens to the batch with proper positions
         for (int j = 0; j < cur_batch_size; j++) {
-            llama_token token_id = tokens[i + j];
-            llama_pos position = start_pos + i + j;
-            bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
+            const llama_token token_id = tokens[i + j];
+            const llama_pos position = pos + j;
+            const bool want_logit = compute_last_logit && (i + j == (int) tokens.size() - 1);
             common_batch_add(*batch, token_id, position, {0}, want_logit);
         }
+        pos += cur_batch_size;
 
         // Decode this batch
-        int decode_result = llama_decode(context, *batch);
+        const int decode_result = llama_decode(context, *batch);
         if (decode_result) {
             LOGe("llama_decode failed w/ %d", decode_result);
             return 1;
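The cache arithmetic above is easier to follow with concrete numbers. The sketch below replays shift_context()'s position bookkeeping as a standalone program; the effects of llama_memory_seq_rm/llama_memory_seq_add are described in comments, and all numbers are illustrative, not values from the real app:

```cpp
// Standalone replay of shift_context()'s position arithmetic (illustrative numbers).
#include <cstdio>

int main() {
    const int system_prompt_position = 64;   // tokens pinned at the front of the cache
    int       current_position       = 1000; // position right after the last cached token

    // Discard the older half of everything decoded after the system prompt.
    const int n_discard = (current_position - system_prompt_position) / 2; // 468

    // What the two cache calls do:
    //   llama_memory_seq_rm (mem, 0,  64,  532)       -> drop cells at positions [64, 532)
    //   llama_memory_seq_add(mem, 0, 532, 1000, -468) -> slide [532, 1000) down to [64, 532)
    current_position -= n_discard;

    std::printf("discarded %d tokens, resuming at position %d\n", n_discard, current_position); // 532
    return 0;
}
```

Halving the post-system-prompt region, rather than dropping a fixed count, keeps shifts infrequent while always leaving the system prompt intact.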
Shifting..."); + shift_context(); + } + // Add tokens to the batch with proper positions for (int j = 0; j < cur_batch_size; j++) { - llama_token token_id = tokens[i + j]; - llama_pos position = start_pos + i + j; - bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); + const llama_token token_id = tokens[i + j]; + const llama_pos position = start_pos + i + j; + const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); common_batch_add(*batch, token_id, position, {0}, want_logit); } // Decode this batch - int decode_result = llama_decode(context, *batch); + const int decode_result = llama_decode(context, *batch); if (decode_result) { LOGe("llama_decode failed w/ %d", decode_result); return 1; @@ -359,10 +391,16 @@ static int decode_tokens_in_batches( /** * Prediction loop's short-term states */ -static llama_pos token_predict_budget; +static llama_pos stop_completion_position; static std::string cached_token_chars; static std::ostringstream assistant_ss; // For storing current assistant message +static void reset_short_term_states() { + stop_completion_position = 0; + cached_token_chars.clear(); + assistant_ss.str(""); +} + extern "C" JNIEXPORT jint JNICALL Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( @@ -370,14 +408,9 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( jobject /*unused*/, jstring jsystem_prompt ) { - // Reset long-term states and reset KV cache - current_position = 0; - llama_memory_clear(llama_get_memory(g_context), false); - - // Reset short-term states - token_predict_budget = 0; - cached_token_chars.clear(); - assistant_ss.str(""); + // Reset long-term & short-term states + reset_long_term_states(); + reset_short_term_states(); // Obtain system prompt from JEnv const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr); @@ -413,11 +446,10 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( } // Update position - current_position = (int) system_tokens.size(); + system_prompt_position = current_position = (int) system_tokens.size(); return 0; } -// TODO-hyin: support KV cache backtracking extern "C" JNIEXPORT jint JNICALL Java_android_llama_cpp_LLamaAndroid_processUserPrompt( @@ -427,9 +459,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( jint n_predict ) { // Reset short-term states - token_predict_budget = 0; - cached_token_chars.clear(); - assistant_ss.str(""); + reset_short_term_states(); // Obtain and tokenize user prompt const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr); @@ -450,22 +480,14 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( } // Ensure user prompt doesn't exceed the context size by truncating if necessary. + const int user_prompt_size = (int) user_tokens.size(); const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM; - if ((int) user_tokens.size() > max_batch_size) { - const int skipped_tokens = (int) user_tokens.size() - max_batch_size; + if (user_prompt_size > max_batch_size) { + const int skipped_tokens = user_prompt_size - max_batch_size; user_tokens.resize(max_batch_size); LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens); } - // TODO-hyin: implement context shifting - // Check if context space is enough for desired tokens - int desired_budget = current_position + (int) user_tokens.size() + n_predict; - if (desired_budget > max_batch_size) { - LOGe("Not enough context! 
@@ -473,7 +503,8 @@
     }
 
     // Update position
-    current_position += (int) user_tokens.size();
+    current_position += user_prompt_size;
+    stop_completion_position = current_position + n_predict;
 
     return 0;
 }
@@ -517,9 +548,17 @@
 Java_android_llama_cpp_LLamaAndroid_completionLoop(
         JNIEnv *env,
         jobject /*unused*/
 ) {
-    // Stop if running out of token budget
-    if (current_position >= token_predict_budget) {
-        LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
+    // Infinite text generation via context shifting
+    if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+        LOGw("Context full! Shifting...");
+        const int n_discard = shift_context(current_position);
+        current_position         -= n_discard;
+        stop_completion_position -= n_discard; // keep this turn's budget at n_predict
+    }
+
+    // Stop if reaching the marked position
+    if (current_position >= stop_completion_position) {
+        LOGw("STOP: hitting stop position: %d", stop_completion_position);
         return nullptr;
     }
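Finally, the interplay of the two checks in completionLoop() can be simulated end to end. The sketch below models shift_context() by its arithmetic alone and uses deliberately small, made-up constants; it shows that pulling stop_completion_position back by the shift amount keeps the per-turn budget at exactly n_predict tokens:

```cpp
// Simulation of the completion loop's shift/stop interplay (made-up constants).
#include <cstdio>

int main() {
    const int CONTEXT_SIZE           = 256;
    const int OVERFLOW_HEADROOM      = 16;
    const int system_prompt_position = 32;

    int current_position         = 200;                    // end of the decoded user prompt
    int stop_completion_position = current_position + 100; // n_predict = 100

    int generated = 0;
    while (true) {
        // Context full: discard the older half after the system prompt and
        // pull both markers back by the same amount (as completionLoop does).
        if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            const int n_discard = (current_position - system_prompt_position) / 2;
            current_position         -= n_discard;
            stop_completion_position -= n_discard;
        }
        // Per-turn budget exhausted: stop.
        if (current_position >= stop_completion_position) break;

        current_position++; // stands in for sampling + decoding one token
        generated++;
    }

    std::printf("generated %d tokens\n", generated); // 100: exactly the n_predict budget
    return 0;
}
```

Without the stop_completion_position adjustment, every shift would silently extend the turn by n_discard tokens.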