Feature: implement infinite conversation via context shifting

parent 4e515727b4
commit ec502cfde9
@@ -311,8 +311,34 @@ constexpr const char *ROLE_SYSTEM = "system";
 constexpr const char *ROLE_USER = "user";
 constexpr const char *ROLE_ASSISTANT = "assistant";
 
-static llama_pos current_position;
 static std::vector<common_chat_msg> chat_msgs;
+static llama_pos system_prompt_position;
+static llama_pos current_position;
+
+static void reset_long_term_states(const bool clear_kv_cache = true) {
+    chat_msgs.clear();
+    system_prompt_position = 0;
+    current_position = 0;
+
+    if (clear_kv_cache)
+        llama_memory_clear(llama_get_memory(g_context), false);
+}
+
+/**
+ * Context shifting by discarding the older half of the tokens appended after the system prompt:
+ * - take the first [system_prompt_position] tokens from the original prompt
+ * - take the newer half of the last (current_position - system_prompt_position) tokens
+ * - recompute the logits in batches
+ */
+static void shift_context() {
+    const int n_discard = (current_position - system_prompt_position) / 2;
+    LOGi("Discarding %d tokens", n_discard);
+
+    llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
+    llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
+    current_position -= n_discard;
+    LOGi("Context shifting done! Current position: %d", current_position);
+}
 
 static std::string chat_add_and_format(const std::string &role, const std::string &content) {
     common_chat_msg new_msg;
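The discard-and-shift arithmetic above can be hard to picture from the llama_memory_* calls alone. The following is a minimal, self-contained sketch (not part of this commit) that models the KV cache as a plain std::vector<int> of positions; the initial values of system_prompt_position and current_position are made up for the demo.

#include <cstdio>
#include <vector>

// Toy model of the KV cache: one entry per cached token, holding its position.
static std::vector<int> cache;
static int system_prompt_position = 4;   // illustrative: 4 system-prompt tokens are kept
static int current_position       = 12;  // illustrative: 12 tokens cached in total

static void shift_context_sketch() {
    const int n_discard = (current_position - system_prompt_position) / 2;

    // Stand-in for llama_memory_seq_rm(): drop the oldest non-system tokens.
    cache.erase(cache.begin() + system_prompt_position,
                cache.begin() + system_prompt_position + n_discard);

    // Stand-in for llama_memory_seq_add(..., -n_discard): shift the survivors back.
    for (size_t i = system_prompt_position; i < cache.size(); i++) {
        cache[i] -= n_discard;
    }
    current_position -= n_discard;
}

int main() {
    for (int p = 0; p < current_position; p++) cache.push_back(p);
    shift_context_sketch();
    // Expect: positions 0..3 (system prompt) followed by 4..7 (newest half, renumbered).
    for (int p : cache) printf("%d ", p);
    printf("\ncurrent_position = %d\n", current_position);
    return 0;
}

With these toy numbers the four system-prompt tokens survive untouched, the oldest four of the remaining eight are dropped, and the newest four are renumbered to sit directly after the system prompt, which is the state shift_context() leaves the real cache in.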
@@ -334,20 +360,26 @@ static int decode_tokens_in_batches(
     // Process tokens in batches using the global batch
     LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos);
     for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
-        int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
+        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
         common_batch_clear(*batch);
         LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i);
 
+        // Shift context if current batch cannot fit into the context
+        if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+            LOGw("Current batch won't fit into context! Shifting...");
+            shift_context();
+        }
+
         // Add tokens to the batch with proper positions
         for (int j = 0; j < cur_batch_size; j++) {
-            llama_token token_id = tokens[i + j];
-            llama_pos position = start_pos + i + j;
-            bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
+            const llama_token token_id = tokens[i + j];
+            const llama_pos position = start_pos + i + j;
+            const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
             common_batch_add(*batch, token_id, position, {0}, want_logit);
         }
 
         // Decode this batch
-        int decode_result = llama_decode(context, *batch);
+        const int decode_result = llama_decode(context, *batch);
         if (decode_result) {
             LOGe("llama_decode failed w/ %d", decode_result);
             return 1;
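The headroom check added above lives inside the batching loop, so here is the index arithmetic in isolation. This sketch (not from the commit) mirrors only the batch splitting, the per-token positions, and the CONTEXT_SIZE - OVERFLOW_HEADROOM comparison; it touches no cache, and BATCH_SIZE, CONTEXT_SIZE, OVERFLOW_HEADROOM and start_pos are assumed demo values rather than the constants defined elsewhere in this file.

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative values only; the real constants live elsewhere in this file.
constexpr int BATCH_SIZE        = 4;
constexpr int CONTEXT_SIZE      = 32;
constexpr int OVERFLOW_HEADROOM = 2;

int main() {
    const std::vector<int> tokens(10, 0); // pretend we have 10 tokens to decode
    const int start_pos = 25;             // close enough to the end to trigger the check

    for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
        const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);

        // Same headroom comparison as decode_tokens_in_batches(); here we only report it.
        if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            printf("batch at i=%d would end at position %d >= %d -> shift_context()\n",
                   i, start_pos + i + cur_batch_size, CONTEXT_SIZE - OVERFLOW_HEADROOM);
        }

        for (int j = 0; j < cur_batch_size; j++) {
            const int  position   = start_pos + i + j;
            const bool want_logit = (i + j == (int) tokens.size() - 1); // only the last token
            printf("token %2d -> position %2d, logits=%d\n", i + j, position, want_logit);
        }
    }
    return 0;
}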
@@ -359,10 +391,16 @@ static int decode_tokens_in_batches(
 /**
  * Prediction loop's short-term states
  */
-static llama_pos token_predict_budget;
+static llama_pos stop_completion_position;
 static std::string cached_token_chars;
 static std::ostringstream assistant_ss; // For storing current assistant message
 
+static void reset_short_term_states() {
+    stop_completion_position = 0;
+    cached_token_chars.clear();
+    assistant_ss.str("");
+}
+
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -370,14 +408,9 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     jobject /*unused*/,
     jstring jsystem_prompt
 ) {
-    // Reset long-term states and reset KV cache
-    current_position = 0;
-    llama_memory_clear(llama_get_memory(g_context), false);
-
-    // Reset short-term states
-    token_predict_budget = 0;
-    cached_token_chars.clear();
-    assistant_ss.str("");
+    // Reset long-term & short-term states
+    reset_long_term_states();
+    reset_short_term_states();
 
     // Obtain system prompt from JEnv
     const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
@@ -413,11 +446,10 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     }
 
     // Update position
-    current_position = (int) system_tokens.size();
+    system_prompt_position = current_position = (int) system_tokens.size();
     return 0;
 }
 
-// TODO-hyin: support KV cache backtracking
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
@@ -427,9 +459,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     jint n_predict
 ) {
     // Reset short-term states
-    token_predict_budget = 0;
-    cached_token_chars.clear();
-    assistant_ss.str("");
+    reset_short_term_states();
 
     // Obtain and tokenize user prompt
     const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
@@ -450,22 +480,14 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+    const int user_prompt_size = (int) user_tokens.size();
     const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM;
-    if ((int) user_tokens.size() > max_batch_size) {
-        const int skipped_tokens = (int) user_tokens.size() - max_batch_size;
+    if (user_prompt_size > max_batch_size) {
+        const int skipped_tokens = user_prompt_size - max_batch_size;
         user_tokens.resize(max_batch_size);
         LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens);
     }
 
-    // TODO-hyin: implement context shifting
-    // Check if context space is enough for desired tokens
-    int desired_budget = current_position + (int) user_tokens.size() + n_predict;
-    if (desired_budget > max_batch_size) {
-        LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size);
-        return 1;
-    }
-    token_predict_budget = desired_budget;
-
     // Decode user tokens in batches
     if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) {
         LOGe("llama_decode() failed!");
@@ -473,7 +495,8 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     }
 
     // Update position
-    current_position += (int) user_tokens.size();
+    current_position += user_prompt_size;
+    stop_completion_position = current_position + user_prompt_size + n_predict;
     return 0;
 }
 
@@ -517,9 +540,15 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop(
     JNIEnv *env,
     jobject /*unused*/
 ) {
-    // Stop if running out of token budget
-    if (current_position >= token_predict_budget) {
-        LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
+    // Infinite text generation via context shifting
+    if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+        LOGw("Context full! Shifting...");
+        shift_context();
+    }
+
+    // Stop if reaching the marked position
+    if (current_position >= stop_completion_position) {
+        LOGw("STOP: hitting stop position: %d", stop_completion_position);
         return nullptr;
     }
 
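To see why the reordered checks in completionLoop allow "infinite" generation, the toy driver below (not part of the commit) repeats the same shift-then-stop decision with made-up constants; it tracks only current_position and omits sampling and the real KV-cache updates.

#include <cstdio>

// Illustrative constants and state only; the real ones are defined earlier in this file.
constexpr int CONTEXT_SIZE      = 16;
constexpr int OVERFLOW_HEADROOM = 2;

static int system_prompt_position   = 3;
static int current_position         = 6;
static int stop_completion_position = 40;   // with these toy numbers the stop check never fires

static void shift_context_sketch() {
    const int n_discard = (current_position - system_prompt_position) / 2;
    current_position -= n_discard;           // KV-cache bookkeeping omitted in this sketch
    printf("  shift: discarded %d tokens, position is now %d\n", n_discard, current_position);
}

int main() {
    // Each iteration stands in for one call to completionLoop(); the run is cut off
    // after 20 simulated tokens since this sketch never reaches the stop position.
    for (int step = 0; step < 20; step++) {
        if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) {
            shift_context_sketch();           // context full: make room and keep generating
        }
        if (current_position >= stop_completion_position) {
            printf("stop at position %d\n", current_position);   // real code returns nullptr
            break;
        }
        current_position++;                   // stands in for decoding one sampled token
        printf("step %2d: position %d\n", step, current_position);
    }
    return 0;
}

Because the shift check runs before the stop check, a full context no longer ends the reply: the position is pushed back under the cap and decoding continues.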