diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 019caacc4a..fceffcc693 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -33,20 +33,20 @@ static std::string join(const std::vector &values, const std::string &delim) /** * LLama resources: context, model, batch and sampler */ -constexpr int N_THREADS_MIN = 1; -constexpr int N_THREADS_MAX = 8; -constexpr int N_THREADS_HEADROOM = 2; +constexpr int N_THREADS_MIN = 1; +constexpr int N_THREADS_MAX = 8; +constexpr int N_THREADS_HEADROOM = 2; -constexpr int CONTEXT_SIZE = 4096; -constexpr int OVERFLOW_HEADROOM = 4; -constexpr int BATCH_SIZE = 512; -constexpr float SAMPLER_TEMP = 0.3f; +constexpr int DEFAULT_CONTEXT_SIZE = 8192; +constexpr int OVERFLOW_HEADROOM = 4; +constexpr int BATCH_SIZE = 512; +constexpr float DEFAULT_SAMPLER_TEMP = 0.3f; -static llama_model * g_model; -static llama_context * g_context; -static llama_batch * g_batch; -static common_sampler * g_sampler; -static common_chat_templates_ptr g_chat_templates; +static llama_model * g_model; +static llama_context * g_context; +static llama_batch * g_batch; +static common_sampler * g_sampler; +static common_chat_templates_ptr g_chat_templates; static void log_callback(ggml_log_level level, const char *fmt, void *data) { int priority; @@ -98,12 +98,11 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file llama_model_params model_params = llama_model_default_params(); const auto *path_to_model = env->GetStringUTFChars(filename, 0); - LOGd("Loading model from: %s", path_to_model); + LOGd("%s: Loading model from: \n%s\n", __func__, path_to_model); auto *model = llama_model_load_from_file(path_to_model, model_params); env->ReleaseStringUTFChars(filename, path_to_model); if (!model) { - LOGe("load_model() failed"); return 1; } g_model = model; @@ -112,31 +111,36 @@ Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring file static llama_context *init_context(llama_model *model) { if (!model) { - LOGe("init_context(): model cannot be null"); + LOGe("%s: model cannot be null", __func__); return nullptr; } // Multi-threading setup - int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, + const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, (int) sysconf(_SC_NPROCESSORS_ONLN) - N_THREADS_HEADROOM)); - LOGi("Using %d threads", n_threads); + LOGi("%s: Using %d threads", __func__, n_threads); // Context parameters setup llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = CONTEXT_SIZE; + const int trained_context_size = llama_model_n_ctx_train(model); + if (DEFAULT_CONTEXT_SIZE > trained_context_size) { + LOGe("%s: Model was trained with only %d context size! Enforcing %d context size...", + __func__, trained_context_size, DEFAULT_CONTEXT_SIZE); + } + ctx_params.n_ctx = DEFAULT_CONTEXT_SIZE; ctx_params.n_batch = BATCH_SIZE; ctx_params.n_ubatch = BATCH_SIZE; ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; auto *context = llama_init_from_model(g_model, ctx_params); if (context == nullptr) { - LOGe("llama_new_context_with_model() returned null)"); + LOGe("%s: llama_new_context_with_model() returned null)", __func__); } return context; } -static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1) { +static llama_batch *new_batch(int n_tokens, int n_seq_max = 1) { // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. auto *batch = new llama_batch{ 0, @@ -148,12 +152,7 @@ static llama_batch *new_batch(int n_tokens, bool embd = false, int n_seq_max = 1 nullptr, }; - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); @@ -177,7 +176,7 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus if (!context) { return 1; } g_context = context; g_batch = new_batch(BATCH_SIZE); - g_sampler = new_sampler(SAMPLER_TEMP); + g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP); g_chat_templates = common_chat_templates_init(g_model, ""); return 0; } @@ -305,7 +304,9 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, /** - * Prediction loop's long-term states + * Completion loop's long-term states: + * - chat management + * - position tracking */ constexpr const char *ROLE_SYSTEM = "system"; constexpr const char *ROLE_USER = "user"; @@ -325,6 +326,8 @@ static void reset_long_term_states(const bool clear_kv_cache = true) { } /** + * TODO-hyin: implement sliding-window version as a better alternative + * * Context shifting by discarding the older half of the tokens appended after system prompt: * - take the [system_prompt_position] first tokens from the original prompt * - take half of the last (system_prompt_position - system_prompt_position) tokens @@ -332,12 +335,11 @@ static void reset_long_term_states(const bool clear_kv_cache = true) { */ static void shift_context() { const int n_discard = (current_position - system_prompt_position) / 2; - LOGi("Discarding %d tokens", n_discard); - + LOGi("%s: Discarding %d tokens", __func__, n_discard); llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard); llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard); current_position -= n_discard; - LOGi("Context shifting done! Current position: %d", current_position); + LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); } static std::string chat_add_and_format(const std::string &role, const std::string &content) { @@ -347,10 +349,26 @@ static std::string chat_add_and_format(const std::string &role, const std::strin auto formatted = common_chat_format_single( g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false); chat_msgs.push_back(new_msg); - LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str()); + LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str()); return formatted; } +/** + * Completion loop's short-term states: + * - stop generation position + * - token chars caching + * - current assistant message being generated + */ +static llama_pos stop_generation_position; +static std::string cached_token_chars; +static std::ostringstream assistant_ss; + +static void reset_short_term_states() { + stop_generation_position = 0; + cached_token_chars.clear(); + assistant_ss.str(""); +} + static int decode_tokens_in_batches( llama_context *context, const llama_tokens &tokens, @@ -358,15 +376,15 @@ static int decode_tokens_in_batches( bool compute_last_logit = false, llama_batch *batch = g_batch) { // Process tokens in batches using the global batch - LOGd("Decode %d tokens starting at position %d", (int) tokens.size(), start_pos); + LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos); for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) { const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE); common_batch_clear(*batch); - LOGv("Preparing a batch size of %d starting at: %d", cur_batch_size, i); + LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i); // Shift context if current batch cannot fit into the context - if (start_pos + i + cur_batch_size >= CONTEXT_SIZE - OVERFLOW_HEADROOM) { - LOGw("Current batch won't fit into context! Shifting..."); + if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { + LOGw("%s: Current batch won't fit into context! Shifting...", __func__); shift_context(); } @@ -381,26 +399,13 @@ static int decode_tokens_in_batches( // Decode this batch const int decode_result = llama_decode(context, *batch); if (decode_result) { - LOGe("llama_decode failed w/ %d", decode_result); + LOGe("%s: llama_decode failed w/ %d", __func__, decode_result); return 1; } } return 0; } -/** - * Prediction loop's short-term states - */ -static llama_pos stop_completion_position; -static std::string cached_token_chars; -static std::ostringstream assistant_ss; // For storing current assistant message - -static void reset_short_term_states() { - stop_completion_position = 0; - cached_token_chars.clear(); - assistant_ss.str(""); -} - extern "C" JNIEXPORT jint JNICALL Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( @@ -414,7 +419,7 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( // Obtain system prompt from JEnv const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr); - LOGd("System prompt received: \n%s", system_prompt); + LOGd("%s: System prompt received: \n%s", __func__, system_prompt); std::string formatted_system_prompt(system_prompt); env->ReleaseStringUTFChars(jsystem_prompt, system_prompt); @@ -432,16 +437,16 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt( } // Handle context overflow - const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM; + const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; if ((int) system_tokens.size() > max_batch_size) { - LOGe("System prompt too long for context! %d tokens, max: %d", - (int) system_tokens.size(), max_batch_size); + LOGe("%s: System prompt too long for context! %d tokens, max: %d", + __func__, (int) system_tokens.size(), max_batch_size); return 1; } // Decode system tokens in batches if (decode_tokens_in_batches(g_context, system_tokens, current_position)) { - LOGe("llama_decode() failed!"); + LOGe("%s: llama_decode() failed!", __func__); return 2; } @@ -463,7 +468,7 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( // Obtain and tokenize user prompt const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr); - LOGd("User prompt received: \n%s", user_prompt); + LOGd("%s: User prompt received: \n%s", __func__, user_prompt); std::string formatted_user_prompt(user_prompt); env->ReleaseStringUTFChars(juser_prompt, user_prompt); @@ -481,22 +486,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt( // Ensure user prompt doesn't exceed the context size by truncating if necessary. const int user_prompt_size = (int) user_tokens.size(); - const int max_batch_size = CONTEXT_SIZE - OVERFLOW_HEADROOM; + const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; if (user_prompt_size > max_batch_size) { const int skipped_tokens = user_prompt_size - max_batch_size; user_tokens.resize(max_batch_size); - LOGw("User prompt too long! Skipped %d tokens!", skipped_tokens); + LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens); } // Decode user tokens in batches if (decode_tokens_in_batches(g_context, user_tokens, current_position, true)) { - LOGe("llama_decode() failed!"); + LOGe("%s: llama_decode() failed!", __func__); return 2; } // Update position current_position += user_prompt_size; - stop_completion_position = current_position + user_prompt_size + n_predict; + stop_generation_position = current_position + user_prompt_size + n_predict; return 0; } @@ -536,19 +541,19 @@ static bool is_valid_utf8(const char *string) { extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_LLamaAndroid_completionLoop( +Java_android_llama_cpp_LLamaAndroid_generateNextToken( JNIEnv *env, jobject /*unused*/ ) { // Infinite text generation via context shifting - if (current_position >= CONTEXT_SIZE - OVERFLOW_HEADROOM) { - LOGw("Context full! Shifting..."); + if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { + LOGw("%s: Context full! Shifting...", __func__); shift_context(); } // Stop if reaching the marked position - if (current_position >= stop_completion_position) { - LOGw("STOP: hitting stop position: %d", stop_completion_position); + if (current_position >= stop_generation_position) { + LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position); return nullptr; } @@ -560,7 +565,7 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop( common_batch_clear(*g_batch); common_batch_add(*g_batch, new_token_id, current_position, {0}, true); if (llama_decode(g_context, *g_batch) != 0) { - LOGe("llama_decode() failed for generated token"); + LOGe("%s: llama_decode() failed for generated token", __func__); return nullptr; } @@ -578,7 +583,7 @@ Java_android_llama_cpp_LLamaAndroid_completionLoop( auto new_token_chars = common_token_to_piece(g_context, new_token_id); cached_token_chars += new_token_chars; - // Create and return Java string + // Create and return a valid UTF-8 Java string jstring result = nullptr; if (is_valid_utf8(cached_token_chars.c_str())) { result = env->NewStringUTF(cached_token_chars.c_str()); diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index 4895086f97..41e84f3fc2 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -25,7 +25,7 @@ class LLamaAndroid { private external fun processSystemPrompt(systemPrompt: String): Int private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int - private external fun completionLoop(): String? + private external fun generateNextToken(): String? /** * Thread local state @@ -128,7 +128,7 @@ class LLamaAndroid { Log.i(TAG, "User prompt processed! Generating assistant prompt...") while (true) { - completionLoop()?.let { utf8token -> + generateNextToken()?.let { utf8token -> if (utf8token.isNotEmpty()) emit(utf8token) } ?: break }