Feature: implement infinite conversation via context shifting
parent 4e515727b4
commit ec502cfde9
@@ -311,8 +311,34 @@ constexpr const char *ROLE_SYSTEM = "system";
 constexpr const char *ROLE_USER = "user";
 constexpr const char *ROLE_ASSISTANT = "assistant";
 
-static llama_pos current_position;
+static std::vector<common_chat_msg> chat_msgs;
+static llama_pos system_prompt_position;
+static llama_pos current_position;
+
+static void reset_long_term_states(const bool clear_kv_cache = true) {
+    chat_msgs.clear();
+    system_prompt_position = 0;
+    current_position = 0;
+
+    if (clear_kv_cache)
+        llama_memory_clear(llama_get_memory(g_context), false);
+}
+
+/**
+ * Context shifting by discarding the older half of the tokens appended after the system prompt:
+ * - keep the first [system_prompt_position] tokens (the system prompt)
+ * - keep the newer half of the last (current_position - system_prompt_position) tokens
+ * - shift the kept tokens back in the KV cache so positions stay contiguous
+ */
+static void shift_context() {
+    const int n_discard = (current_position - system_prompt_position) / 2;
+    LOGi("Discarding %d tokens", n_discard);
+
+    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
+    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
+    current_position -= n_discard;
+    LOGi("Context shifting done! Current position: %d", current_position);
+}
+
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
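For readers skimming the diff, the position arithmetic in shift_context() can be checked in isolation. The standalone sketch below is not part of the commit; all numbers are invented (a 16-token system prompt, a conversation grown to position 100), and only the n_discard formula mirrors the function above.

    #include <cstdio>

    // Hypothetical walkthrough of the shift_context() position arithmetic.
    int main() {
        int system_prompt_position = 16;   // assumed: size of the pinned system prompt
        int current_position       = 100;  // assumed: next write position before the shift

        // Same formula as shift_context(): drop the older half of the
        // tokens that come after the system prompt.
        const int n_discard = (current_position - system_prompt_position) / 2;  // 42

        // KV cells [16, 58) are removed and cells [58, 100) slide back by 42,
        // so the next token will be written at position 58.
        current_position -= n_discard;

        std::printf("discarded %d tokens, new position %d\n", n_discard, current_position);
        return 0;
    }

After the shift the earliest turns are gone, but the system prompt and the most recent half of the conversation remain addressable at contiguous positions.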
@@ -334,20 +360,26 @@ static int decode_tokens_in_batches(
     // Process tokens in batches using the global batch
     LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
-        int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
+        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
         LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
 
+        // Shift context if current batch cannot fit into the context
+        if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+            LOGw("Current batch won't fit into context! Shifting...");
+            shift_context();
+        }
+
         // Add tokens to the batch with proper positions
         for (int j = 0; j < cur_batch_size; j++) {
-            llama_token token_id = tokens[i + j];
-            llama_pos position = start_pos + i + j;
-            bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
+            const llama_token token_id = tokens[i + j];
+            const llama_pos position = start_pos + i + j;
+            const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
             common_batch_add(*batch, token_id, position, {0}, want_logit);
         }
 
         // Decode this batch
-        int decode_result = llama_decode(context, *batch);
+        const int decode_result = llama_decode(context, *batch);
         if (decode_result) {
             LOGe("llama_decode failed w/ %d", decode_result);
             return 1;
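The loop above splits the token list into BATCH_SIZE chunks and gives each token the absolute position start_pos + i + j; the new guard calls shift_context() when a chunk would run into the headroom at the end of the context. A minimal standalone sketch of that index math follows (toy sizes, no llama.cpp calls; it only prints where each chunk would land).

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        // Assumed toy sizes, not the constants from the JNI file.
        const int BATCH_SIZE        = 32;
        const int CONTEXT_SIZE      = 128;
        const int OVERFLOW_HEADROOM = 8;

        std::vector<int> tokens(70);   // pretend token ids
        const int start_pos = 90;      // position where decoding resumes

        for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
            const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);

            // Same guard as the commit: if this chunk would reach into the
            // headroom at the end of the context, shift_context() runs first.
            const bool needs_shift =
                start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM;

            std::printf("chunk at offset %d: %d tokens, positions [%d, %d)%s\n",
                        i, cur_batch_size, start_pos + i, start_pos + i + cur_batch_size,
                        needs_shift ? " -> would trigger shift_context()" : "");
        }
        return 0;
    }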
@@ -359,10 +391,16 @@ static int decode_tokens_in_batches(
 /**
  * Prediction loop's short-term states
  */
-static llama_pos token_predict_budget;
+static llama_pos stop_completion_position;
 static std::string cached_token_chars;
 static std::ostringstream assistant_ss; // For storing current assistant message
 
+static void reset_short_term_states() {
+    stop_completion_position = 0;
+    cached_token_chars.clear();
+    assistant_ss.str("");
+}
+
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -370,14 +408,9 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
         jobject /*unused*/,
         jstring jsystem_prompt
 ) {
-    // Reset long-term states and reset KV cache
-    current_position = 0;
-    llama_memory_clear(llama_get_memory(g_context), false);
-
-    // Reset short-term states
-    token_predict_budget = 0;
-    cached_token_chars.clear();
-    assistant_ss.str("");
+    // Reset long-term & short-term states
+    reset_long_term_states();
+    reset_short_term_states();
 
     // Obtain system prompt from JEnv
     const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
@@ -413,11 +446,10 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Update position
-    current_position = (int) system_tokens.size();
+    system_prompt_position = current_position = (int) system_tokens.size();
     return 0;
 }
 
 // TODO-hyin: support KV cache backtracking
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
@@ -427,9 +459,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
         jint n_predict
 ) {
     // Reset short-term states
-    token_predict_budget = 0;
-    cached_token_chars.clear();
-    assistant_ss.str("");
+    reset_short_term_states();
 
     // Obtain and tokenize user prompt
     const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
@@ -450,22 +480,14 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+    const int user_prompt_size = (int) user_tokens.size();
     const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
-    if ((int) user_tokens.size() > max_batch_size) {
-        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
+    if (user_prompt_size > max_batch_size) {
+        const int skipped_tokens = user_prompt_size - max_batch_size;
         user_tokens.resize(max_batch_size);
         LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
     }
 
-    // TODO-hyin: implement context shifting
-    // Check if context space is enough for desired tokens
-    int desired_budget = current_position + (int) user_tokens.size() + n_predict;
-    if (desired_budget > max_batch_size) {
-        LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
-        return 1;
-    }
-    token_predict_budget = desired_budget;
-
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
         LOGe("llama_decode() failed!");
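With the hard "not enough context" failure removed, an over-long user prompt is simply clamped to max_batch_size before decoding. A small hedged example of that clamping with invented sizes (a 150-token prompt against a 120-token cap):

    #include <cstdio>
    #include <vector>

    int main() {
        // Assumed toy constants; the real ones are defined in the JNI layer.
        const int CONTEXT_SIZE      = 128;
        const int OVERFLOW_HEADROOM = 8;
        const int max_batch_size    = CONTEXT_SIZE - OVERFLOW_HEADROOM;  // 120

        std::vector<int> user_tokens(150);  // pretend an over-long prompt
        const int user_prompt_size = (int) user_tokens.size();

        if (user_prompt_size > max_batch_size) {
            const int skipped_tokens = user_prompt_size - max_batch_size;  // 30
            user_tokens.resize(max_batch_size);                            // keep the first 120
            std::printf("User prompt too long! Skipped %d tokens\n", skipped_tokens);
        }
        return 0;
    }

Note that the resize keeps the earliest max_batch_size tokens and drops the tail, and user_prompt_size is captured before the resize so the log can report how many tokens were cut.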
@@ -473,7 +495,8 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Update position
-    current_position += (int) user_tokens.size();
+    current_position += user_prompt_size;
+    stop_completion_position = current_position + user_prompt_size + n_predict;
     return 0;
 }
 
@@ -517,9 +540,15 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop(
         JNIEnv *env,
         jobject /*unused*/
 ) {
-    // Stop if running out of token budget
-    if (current_position >= token_predict_budget) {
-        LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
+    // Infinite text generation via context shifting
+    if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+        LOGw("Context full! Shifting...");
+        shift_context();
+    }
+
+    // Stop if reaching the marked position
+    if (current_position >= stop_completion_position) {
+        LOGw("STOP: hitting stop position: %d", stop_completion_position);
         return nullptr;
     }
 
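Taken together with the earlier hunks, the completion loop now makes room by shifting instead of bailing out when the context fills, and it stops at a precomputed position rather than a running budget counter. The toy model below restates that control flow; it is standalone, all names and numbers are invented, and the iteration cap merely keeps the toy finite since the real loop has additional stopping conditions not modeled here.

    #include <cstdio>

    // Toy restatement of the new completionLoop control flow. All values are
    // invented for illustration and are not the constants from the JNI file.
    static int current_position         = 0;
    static int stop_completion_position = 0;
    static const int CONTEXT_SIZE       = 128;
    static const int OVERFLOW_HEADROOM  = 8;

    static void shift_context_stub() {
        // Stand-in for shift_context(): the real one keeps the system prompt
        // and drops the older half of what follows; here we simply halve.
        current_position /= 2;
    }

    // One iteration: shift if the context is full, stop if the marked
    // position is reached, otherwise pretend a token was sampled and decoded.
    static bool step() {
        if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            shift_context_stub();
        }
        if (current_position >= stop_completion_position) {
            return false;
        }
        ++current_position;
        return true;
    }

    int main() {
        stop_completion_position = 500;  // a budget larger than the whole context
        int emitted = 0;
        while (emitted < 300 && step()) {
            ++emitted;                   // the cap keeps this toy finite
        }
        std::printf("emitted %d tokens without exhausting the %d-token context "
                    "(final position %d)\n", emitted, CONTEXT_SIZE, current_position);
        return 0;
    }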