From 4e515727b43acc6e3d3b950d545722462d15bbdc Mon Sep 17 00:00:00 2001 From: Han Yin Date: Tue, 8 Apr 2025 11:27:00 -0700 Subject: [PATCH] Abort on system prompt too long; Truncate user prompt if too long. --- .../llama/src/main/cpp/llama-android.cpp | 81 +++++++++++-------- 1 file changed, 49 insertions(+), 32 deletions(-) diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 4a0492aebb..344308763f 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -10,7 +10,8 @@ #include "common.h" #include "llama.h" -template static std::string join(const std::vector & values, const std::string & delim) { +template +static std::string join(const std::vector &values, const std::string &delim) { std::ostringstream str; for (size_t i = 0; i < values.size(); i++) { str << values[i]; @@ -37,6 +38,7 @@ constexpr int N_THREADS_MAX = 8; constexpr int N_THREADS_HEADROOM = 2; constexpr int CONTEXT_SIZE = 4096; +constexpr int OVERFLOW_HEADROOM = 4; constexpr int BATCH_SIZE = 512; constexpr float SAMPLER_TEMP = 0.3f; @@ -44,7 +46,7 @@ static llama_model * g_model; static llama_context * g_context; static llama_batch * g_batch; static common_sampler * g_sampler; -static common_chat_templates_ptr g_chat_templates; +static common_chat_templates_ptr g_chat_templates; static void log_callback(ggml_log_level level, const char *fmt, void *data) { int priority; @@ -68,9 +70,9 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) { __android_log_print(priority, TAG, fmt, data); } -JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) { - JNIEnv* env; - if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { +JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) { + JNIEnv *env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { return JNI_ERR; } @@ -102,13 +104,13 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file env->ReleaseStringUTFChars(filename, path_to_model); if (!model) { LOGe("load_model() failed"); - return -1; + return 1; } g_model = model; return 0; } -static llama_context* init_context(llama_model *model) { +static llama_context *init_context(llama_model *model) { if (!model) { LOGe("init_context(): model cannot be null"); return nullptr; @@ -134,7 +136,7 @@ static llama_context* init_context(llama_model *model) { return context; } -static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) { +static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) { // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. auto *batch = new llama_batch{ 0, @@ -162,7 +164,7 @@ static llama_batch * new_batch(int n_tokens, bool embd = false, int n_seq_max = return batch; } -static common_sampler* new_sampler(float temp) { +static common_sampler *new_sampler(float temp) { common_params_sampling sparams; sparams.temp = temp; return common_sampler_init(g_model, sparams); @@ -172,7 +174,7 @@ extern "C" JNIEXPORT jint JNICALL Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) { auto *context = init_context(g_model); - if (!context) { return -1; } + if (!context) { return 1; } g_context = context; g_batch = new_batch(BATCH_SIZE); g_sampler = new_sampler(SAMPLER_TEMP); @@ -194,7 +196,7 @@ Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unuse static std::string get_backend() { std::vector backends; for (size_t i = 0; i < ggml_backend_reg_count(); i++) { - auto * reg = ggml_backend_reg_get(i); + auto *reg = ggml_backend_reg_get(i); std::string name = ggml_backend_reg_name(reg); if (name != "CPU") { backends.push_back(ggml_backend_reg_name(reg)); @@ -205,7 +207,8 @@ static std::string get_backend() { extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) { +Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, + jint pl, jint nr) { auto pp_avg = 0.0; auto tg_avg = 0.0; auto pp_std = 0.0; @@ -304,14 +307,14 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, /** * Prediction loop's long-term states */ -constexpr const char* ROLE_SYSTEM = "system"; -constexpr const char* ROLE_USER = "user"; -constexpr const char* ROLE_ASSISTANT = "assistant"; +constexpr const char *ROLE_SYSTEM = "system"; +constexpr const char *ROLE_USER = "user"; +constexpr const char *ROLE_ASSISTANT = "assistant"; static llama_pos current_position; static std::vector chat_msgs; -static std::string chat_add_and_format(const std::string & role, const std::string & content) { +static std::string chat_add_and_format(const std::string &role, const std::string &content) { common_chat_msg new_msg; new_msg.role = role; new_msg.content = content; @@ -324,12 +327,12 @@ static std::string chat_add_and_format(const std::string & role, const std::stri static int decode_tokens_in_batches( llama_context *context, - const llama_tokens& tokens, + const llama_tokens &tokens, const llama_pos start_pos, bool compute_last_logit = false, llama_batch *batch = g_batch) { // Process tokens in batches using the global batch - LOGd("Decode %d tokens starting at position %d", tokens.size(), start_pos); + LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos); for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) { int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE); common_batch_clear(*batch); @@ -347,10 +350,9 @@ static int decode_tokens_in_batches( int decode_result = llama_decode(context, *batch); if (decode_result) { LOGe("llama_decode failed w/ %d", decode_result); - return -1; + return 1; } } - return 0; } @@ -390,17 +392,24 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( } // Tokenize system prompt - const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template); - for (auto id : system_tokens) { + const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, + has_chat_template, has_chat_template); + for (auto id: system_tokens) { LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); } - // TODO-hyin: handle context overflow + // Handle context overflow + const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM; + if ((int) system_tokens.size() > max_batch_size) { + LOGe("System prompt too long for context! %d tokens, max: %d", + (int) system_tokens.size(), max_batch_size); + return 1; + } // Decode system tokens in batches if (decode_tokens_in_batches(g_context, system_tokens, current_position)) { LOGe("llama_decode() failed!"); - return -1; + return 2; } // Update position @@ -435,28 +444,36 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( } // Decode formatted user prompts - const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template); - for (auto id : user_tokens) { + auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template); + for (auto id: user_tokens) { LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); } - // TODO-hyin: handle context overflow + // Ensure user prompt doesn't exceed the context size by truncating if necessary. + const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM; + if ((int) user_tokens.size() > max_batch_size) { + const int skipped_tokens = (int) user_tokens.size() - max_batch_size; + user_tokens.resize(max_batch_size); + LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens); + } + + // TODO-hyin: implement context shifting // Check if context space is enough for desired tokens int desired_budget = current_position + (int) user_tokens.size() + n_predict; - if (desired_budget > llama_n_ctx(g_context)) { - LOGe("error: total tokens exceed context size"); - return -1; + if (desired_budget > max_batch_size) { + LOGe("Not enough context! %d total tokens, max: %d", desired_budget, max_batch_size); + return 1; } token_predict_budget = desired_budget; // Decode user tokens in batches if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) { LOGe("llama_decode() failed!"); - return -2; + return 2; } // Update position - current_position += (int) user_tokens.size(); // Update position + current_position += (int) user_tokens.size(); return 0; }