From 8bf2f4d412ac2c9680e3d3c9c75dd7352e4f1e2d Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 7 Apr 2025 20:37:33 -0700
Subject: [PATCH] Feature: chat template auto formatting

---
 .../java/com/example/llama/MainViewModel.kt |  2 +-
 .../llama/src/main/cpp/llama-android.cpp    | 96 ++++++++++++++-----
 .../java/android/llama/cpp/LLamaAndroid.kt  | 14 +--
 3 files changed, 78 insertions(+), 34 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index 9b1aa9d96c..bb29cb08f6 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -46,7 +46,7 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
 
         viewModelScope.launch {
             // TODO-hyin: implement format message
-            llamaAndroid.sendUserPrompt(formattedMessage = text)
+            llamaAndroid.sendUserPrompt(message = text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 098cc747a4..415cc197de 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -5,8 +5,10 @@
 #include
 #include
 #include
-#include "llama.h"
+
+#include "chat.h"
 #include "common.h"
+#include "llama.h"
 template <typename T>
 static std::string join(const std::vector<T> & values, const std::string & delim) {
     std::ostringstream str;
@@ -34,14 +36,15 @@
 constexpr int N_THREADS_MIN = 1;
 constexpr int N_THREADS_MAX = 8;
 constexpr int N_THREADS_HEADROOM = 2;
-constexpr int CONTEXT_SIZE = 4096;
-constexpr int BATCH_SIZE = 512;
-constexpr float SAMPLER_TEMP = 0.3f;
+constexpr int   CONTEXT_SIZE = 4096;
+constexpr int   BATCH_SIZE   = 512;
+constexpr float SAMPLER_TEMP = 0.3f;
 
-static llama_model * g_model;
-static llama_context * g_context;
-static llama_batch * g_batch;
-static common_sampler * g_sampler;
+static llama_model *             g_model;
+static llama_context *           g_context;
+static llama_batch *             g_batch;
+static common_sampler *          g_sampler;
+static common_chat_templates_ptr g_chat_templates;
 
 static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     int priority;
@@ -174,16 +177,18 @@ Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unus
     if (ret != 0) { return ret; }
 
     new_batch(BATCH_SIZE);
     new_sampler(SAMPLER_TEMP);
+    g_chat_templates = common_chat_templates_init(g_model, "");
     return 0;
 }
 
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_model_free(g_model);
-    llama_free(g_context);
-    delete g_batch;
+    g_chat_templates.reset();
     common_sampler_free(g_sampler);
+    delete g_batch;
+    llama_free(g_context);
+    llama_model_free(g_model);
     llama_backend_free();
 }
@@ -298,12 +303,25 @@ Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/,
 
 /**
- * Prediction loop's long-term and short-term states
+ * Prediction loop's long-term states
  */
-static llama_pos current_position;
+constexpr const char* ROLE_SYSTEM    = "system";
+constexpr const char* ROLE_USER      = "user";
+constexpr const char* ROLE_ASSISTANT = "assistant";
 
-static llama_pos token_predict_budget;
-static std::string cached_token_chars;
+static llama_pos current_position;
+static std::vector<common_chat_msg> chat_msgs;
+
+static std::string chat_add_and_format(const std::string & role, const std::string & content) {
+    common_chat_msg new_msg;
+    new_msg.role    = role;
+    new_msg.content = content;
+    auto formatted = common_chat_format_single(
+        g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
+    chat_msgs.push_back(new_msg);
+    LOGi("Formatted and added %s message: \n%s\n", role.c_str(), formatted.c_str());
+    return formatted;
+}
 
 static int decode_tokens_in_batches(
     llama_context *context,
@@ -337,6 +355,13 @@ static int decode_tokens_in_batches(
     return 0;
 }
 
+/**
+ * Prediction loop's short-term states
+ */
+static llama_pos token_predict_budget;
+static std::string cached_token_chars;
+static std::ostringstream assistant_ss; // For storing current assistant message
+
 extern "C"
 JNIEXPORT jint JNICALL
 Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
@@ -351,14 +376,22 @@ Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
     // Reset short-term states
     token_predict_budget = 0;
     cached_token_chars.clear();
+    assistant_ss.str("");
 
-    // Obtain and tokenize system prompt
-    const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
-    LOGd("System prompt received: \n%s", system_text);
-    const auto system_tokens = common_tokenize(g_context, system_text, true, true);
-    env->ReleaseStringUTFChars(jsystem_prompt, system_text);
+    // Obtain system prompt from JEnv
+    const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
+    LOGd("System prompt received: \n%s", system_prompt);
+    std::string formatted_system_prompt(system_prompt);
+    env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
 
-    // Print each token in verbose mode
+    // Format system prompt if applicable
+    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+    if (has_chat_template) {
+        formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, formatted_system_prompt);
+    }
+
+    // Tokenize system prompt
+    const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, has_chat_template, has_chat_template);
     for (auto id : system_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
@@ -386,14 +419,22 @@ Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
     // Reset short-term states
     token_predict_budget = 0;
     cached_token_chars.clear();
+    assistant_ss.str("");
 
     // Obtain and tokenize user prompt
-    const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
-    LOGd("User prompt received: \n%s", user_text);
-    const auto user_tokens = common_tokenize(g_context, user_text, true, true);
-    env->ReleaseStringUTFChars(juser_prompt, user_text);
+    const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
+    LOGd("User prompt received: \n%s", user_prompt);
+    std::string formatted_user_prompt(user_prompt);
+    env->ReleaseStringUTFChars(juser_prompt, user_prompt);
 
-    // Print each token in verbose mode
+    // Format user prompt if applicable
+    const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+    if (has_chat_template) {
+        formatted_user_prompt = chat_add_and_format(ROLE_USER, formatted_user_prompt);
+    }
+
+    // Tokenize formatted user prompt
+    const auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
     for (auto id : user_tokens) {
         LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
     }
@@ -481,6 +522,7 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
     // Stop if next token is EOG
     if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
         LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
         return nullptr;
     }
 
@@ -493,6 +535,8 @@ Java_android_llama_cpp_LLamaAndroid_predictLoop(
     if (is_valid_utf8(cached_token_chars.c_str())) {
         result = env->NewStringUTF(cached_token_chars.c_str());
         LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
+
+        assistant_ss << cached_token_chars;
         cached_token_chars.clear();
     } else {
         LOGv("id: %d,\tappend to cache", new_token_id);
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index e84259422f..b99ab2867b 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -57,7 +57,7 @@ class LLamaAndroid {
     /**
      * Load the LLM, then process the formatted system prompt if provided
      */
-    suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
+    suspend fun load(pathToModel: String, systemPrompt: String? = null) =
         withContext(runLoop) {
             when (threadLocalState.get()) {
                 is State.NotInitialized -> {
@@ -70,8 +70,8 @@ class LLamaAndroid {
                     Log.i(TAG, "Loaded model $pathToModel")
                     threadLocalState.set(State.EnvReady)
 
-                    formattedSystemPrompt?.let {
-                        initWithSystemPrompt(formattedSystemPrompt)
+                    systemPrompt?.let {
+                        initWithSystemPrompt(systemPrompt)
                     } ?: run {
                         Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.AwaitingUserPrompt)
@@ -108,10 +108,10 @@ class LLamaAndroid {
      * Send formatted user prompt to LLM
      */
     fun sendUserPrompt(
-        formattedMessage: String,
-        nPredict: Int = DEFAULT_PREDICT_LENGTH,
+        message: String,
+        predictLength: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        require(formattedMessage.isNotEmpty()) {
+        require(message.isNotEmpty()) {
             Log.w(TAG, "User prompt discarded due to being empty!")
         }
 
@@ -119,7 +119,7 @@ class LLamaAndroid {
             is State.AwaitingUserPrompt -> {
                 Log.i(TAG, "Sending user prompt...")
                 threadLocalState.set(State.Processing)
-                processUserPrompt(formattedMessage, nPredict).let { result ->
+                processUserPrompt(message, predictLength).let { result ->
                     if (result != 0) {
                         Log.e(TAG, "Failed to process user prompt: $result")
                         return@flow
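
Illustrative caller-side usage (not part of the patch): a minimal sketch of how the reworked Kotlin API could be driven after this change. It assumes the existing LLamaAndroid.instance() accessor; the model path, log tag, and prompt strings are placeholders. Prompts are now passed through raw, since chat-template formatting happens in the native layer via chat_add_and_format().

    import android.llama.cpp.LLamaAndroid
    import android.util.Log
    import kotlinx.coroutines.flow.catch

    // Hypothetical demo; assumes LLamaAndroid.instance() and a suspending caller.
    suspend fun runChatDemo() {
        val llama = LLamaAndroid.instance()

        // Plain system prompt; the native side applies the model's chat template
        // only when the model ships one explicitly.
        llama.load(
            pathToModel = "/data/local/tmp/model.gguf",    // placeholder path
            systemPrompt = "You are a helpful assistant."  // placeholder prompt
        )

        // Raw user message; the returned Flow<String> emits decoded UTF-8 pieces
        // until the model produces an end-of-generation token.
        llama.sendUserPrompt(message = "Hello!")
            .catch { Log.e("ChatDemo", "prediction failed", it) }
            .collect { piece -> print(piece) }
    }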