From 0ade7fb4d7e49041d9df43f2a77c73f2b1599c4a Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 31 Mar 2025 12:53:00 -0700
Subject: [PATCH] Polish binding: Remove verbose setup JNI APIs; Update state
 machine states.

---
 .../llama/src/main/cpp/llama-android.cpp   |  74 +++++++-----
 .../java/android/llama/cpp/LLamaAndroid.kt | 109 +++++++++---------
 2 files changed, 101 insertions(+), 82 deletions(-)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 109e630727..366d721fb7 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -12,10 +12,11 @@
  * Logging utils
  */
 #define TAG "llama-android.cpp"
-#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
+#define LOGv(...) __android_log_print(ANDROID_LOG_VERBOSE, TAG, __VA_ARGS__)
+#define LOGd(...) __android_log_print(ANDROID_LOG_DEBUG, TAG, __VA_ARGS__)
+#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGw(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
+#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
 
 /**
  * LLama resources: context, model, batch and sampler
@@ -55,32 +56,35 @@ static void log_callback(ggml_log_level level, const char *fmt, void *data) {
     __android_log_print(priority, TAG, fmt, data);
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_log_set(log_callback, nullptr);
-}
+JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* reserved) {
+    JNIEnv* env;
+    if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
+        return JNI_ERR;
+    }
+
+    // Set llama log handler to Android
+    llama_log_set(log_callback, nullptr);
+
+    // Initialize backends
+    llama_backend_init();
+    LOGi("Backend initiated.");
+
+    return JNI_VERSION_1_6;
+}
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_systemInfo(JNIEnv *env, jobject /*unused*/) {
     return env->NewStringUTF(llama_print_system_info());
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv * /*unused*/, jobject /*unused*/) {
-    llama_backend_init();
-}
-
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
+Java_android_llama_cpp_LLamaAndroid_loadModel(JNIEnv *env, jobject, jstring filename) {
     llama_model_params model_params = llama_model_default_params();
 
     const auto *path_to_model = env->GetStringUTFChars(filename, 0);
-    LOGi("Loading model from: %s", path_to_model);
+    LOGd("Loading model from: %s", path_to_model);
 
     model = llama_model_load_from_file(path_to_model, model_params);
     env->ReleaseStringUTFChars(filename, path_to_model);
@@ -153,7 +157,7 @@ void new_sampler(float temp) {
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_initContext(JNIEnv * /*env*/, jobject /*unused*/) {
     int ret = init_context();
     if (ret != 0) { return ret; }
     new_batch(BATCH_SIZE);
@@ -163,7 +167,7 @@ Java_android_llama_cpp_LLamaAndroid_ctx_1init(JNIEnv * /*env*/, jobject /*unused
 extern "C"
 JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unused*/, jobject /*unused*/) {
+Java_android_llama_cpp_LLamaAndroid_cleanUp(JNIEnv * /*unused*/, jobject /*unused*/) {
     llama_model_free(model);
     llama_free(context);
     llama_backend_free();
@@ -173,7 +177,7 @@ Java_android_llama_cpp_LLamaAndroid_clean_1up(JNIEnv * /*unu
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_bench_1model(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
+Java_android_llama_cpp_LLamaAndroid_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, jint pl, jint nr) {
     auto pp_avg = 0.0;
     auto tg_avg = 0.0;
     auto pp_std = 0.0;
@@ -284,7 +288,7 @@ std::string cached_token_chars;
 
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
+Java_android_llama_cpp_LLamaAndroid_processSystemPrompt(
         JNIEnv *env,
         jobject /*unused*/,
         jstring jsystem_prompt
@@ -299,10 +303,15 @@ Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
 
     // Obtain and tokenize system prompt
     const auto *const system_text = env->GetStringUTFChars(jsystem_prompt, nullptr);
-    LOGi("System prompt: \n%s", system_text);
+    LOGd("System prompt received: \n%s", system_text);
     const auto system_tokens = common_tokenize(context, system_text, true, true);
     env->ReleaseStringUTFChars(jsystem_prompt, system_text);
 
+    // Print each token in verbose mode
+    for (auto id : system_tokens) {
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+    }
+
     // Add system prompt tokens to batch
     common_batch_clear(*batch);
     // TODO-hyin: support batch processing!
@@ -325,11 +334,11 @@ Java_android_llama_cpp_LLamaAndroid_process_1system_1prompt(
 // TODO-hyin: support KV cache backtracking
 extern "C"
 JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
+Java_android_llama_cpp_LLamaAndroid_processUserPrompt(
         JNIEnv *env,
         jobject /*unused*/,
         jstring juser_prompt,
-        jint n_len
+        jint n_predict
 ) {
     // Reset short-term states
     token_predict_budget = 0;
@@ -337,12 +346,17 @@ Java_android_llama_cpp_LLamaAndroid_process_1user_1prompt(
 
     // Obtain and tokenize user prompt
     const auto *const user_text = env->GetStringUTFChars(juser_prompt, nullptr);
-    LOGi("User prompt: \n%s", user_text);
+    LOGd("User prompt received: \n%s", user_text);
     const auto user_tokens = common_tokenize(context, user_text, true, true);
     env->ReleaseStringUTFChars(juser_prompt, user_text);
 
+    // Print each token in verbose mode
+    for (auto id : user_tokens) {
+        LOGv("token: `%s`\t -> `%d`", common_token_to_piece(context, id).c_str(), id);
+    }
+
     // Check if context space is enough for desired tokens
-    int desired_budget = current_position + user_tokens.size() + n_len;
+    int desired_budget = current_position + user_tokens.size() + n_predict;
     if (desired_budget > llama_n_ctx(context)) {
         LOGe("error: total tokens exceed context size");
         return -1;
@@ -404,13 +418,13 @@ bool is_valid_utf8(const char *string) {
 
 extern "C"
 JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_predict_1loop(
+Java_android_llama_cpp_LLamaAndroid_predictLoop(
         JNIEnv *env,
         jobject /*unused*/
 ) {
     // Stop if running out of token budget
     if (current_position >= token_predict_budget) {
-        LOGi("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
+        LOGw("STOP: current position (%d) exceeds budget (%d)", current_position, token_predict_budget);
         return nullptr;
     }
 
@@ -420,7 +434,7 @@ Java_android_llama_cpp_LLamaAndroid_predict_1loop(
 
     // Stop if next token is EOG
     if (llama_vocab_is_eog(llama_model_get_vocab(model), new_token_id)) {
-        LOGi("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+        LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
         return nullptr;
     }
 
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index a8102eea04..ed46c16713 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -15,42 +15,36 @@ class LLamaAndroid {
      * JNI methods
      * @see llama-android.cpp
      */
-    private external fun log_to_android()
-    private external fun system_info(): String
-    private external fun backend_init()
+    private external fun systemInfo(): String
 
-    private external fun load_model(filename: String): Int
-    private external fun ctx_init(): Int
-    private external fun clean_up()
+    private external fun loadModel(filename: String): Int
+    private external fun initContext(): Int
+    private external fun cleanUp()
 
-    private external fun bench_model(pp: Int, tg: Int, pl: Int, nr: Int): String
+    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
 
-    private external fun process_system_prompt(system_prompt: String): Int
-    private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
-    private external fun predict_loop(): String?
+    private external fun processSystemPrompt(systemPrompt: String): Int
+    private external fun processUserPrompt(userPrompt: String, nPredict: Int): Int
+    private external fun predictLoop(): String?
 
     /**
      * Thread local state
     */
     private sealed interface State {
-        data object Idle: State
-        data object ModelLoaded: State
-        data object ReadyForUserPrompt: State
+        data object NotInitialized: State
+        data object EnvReady: State
+        data object AwaitingUserPrompt: State
+        data object Processing: State
     }
 
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.NotInitialized }
 
     private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
+        thread(start = false, name = LLAMA_THREAD) {
             Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
 
             // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init()
-
-            Log.d(TAG, system_info())
+            System.loadLibrary(LIB_LLAMA_ANDROID)
+            Log.d(TAG, systemInfo())
 
             it.run()
         }.apply {
@@ -61,26 +55,26 @@ class LLamaAndroid {
     }.asCoroutineDispatcher()
 
     /**
-     * Load the LLM, then process the system prompt if provided
+     * Load the LLM, then process the formatted system prompt if provided
      */
     suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
-                is State.Idle -> {
-                    val model = load_model(pathToModel)
-                    if (model != 0) throw IllegalStateException("Load model failed")
+                is State.NotInitialized -> {
+                    val modelResult = loadModel(pathToModel)
+                    if (modelResult != 0) throw IllegalStateException("Load model failed: $modelResult")
 
-                    val result = ctx_init()
-                    if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
+                    val initResult = initContext()
+                    if (initResult != 0) throw IllegalStateException("Initialization failed with error code: $initResult")
 
                     Log.i(TAG, "Loaded model $pathToModel")
-                    threadLocalState.set(State.ModelLoaded)
+                    threadLocalState.set(State.EnvReady)
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
                     } ?: {
                         Log.w(TAG, "No system prompt to process.")
-                        threadLocalState.set(State.ReadyForUserPrompt)
+                        threadLocalState.set(State.AwaitingUserPrompt)
                     }
                 }
                 else -> throw IllegalStateException("Model already loaded")
@@ -94,15 +88,16 @@ class LLamaAndroid {
     private suspend fun initWithSystemPrompt(formattedMessage: String) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
+                is State.EnvReady -> {
                     Log.i(TAG, "Process system prompt...")
-                    process_system_prompt(formattedMessage).let {
+                    threadLocalState.set(State.Processing)
+                    processSystemPrompt(formattedMessage).let {
                         if (it != 0) throw IllegalStateException("Failed to process system prompt: $it")
                     }
 
                     Log.i(TAG, "System prompt processed!")
-                    threadLocalState.set(State.ReadyForUserPrompt)
+                    threadLocalState.set(State.AwaitingUserPrompt)
                 }
                 else -> throw IllegalStateException(
                     "Failed to process system prompt: Model not loaded!"
@@ -112,31 +107,36 @@ class LLamaAndroid {
     }
 
     /**
-     * Send plain text user prompt to LLM
+     * Send formatted user prompt to LLM
      */
     fun sendUserPrompt(
         formattedMessage: String,
         nPredict: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
-        when (threadLocalState.get()) {
-            is State.ReadyForUserPrompt -> {
-                process_user_prompt(formattedMessage, nPredict).let {
-                    if (it != 0) {
-                        Log.e(TAG, "Failed to process user prompt: $it")
+        when (val state = threadLocalState.get()) {
+            is State.AwaitingUserPrompt -> {
+                Log.i(TAG, "Sending user prompt...")
+                threadLocalState.set(State.Processing)
+                processUserPrompt(formattedMessage, nPredict).let { result ->
+                    if (result != 0) {
+                        Log.e(TAG, "Failed to process user prompt: $result")
                         return@flow
                     }
                 }
 
                 Log.i(TAG, "User prompt processed! Generating assistant prompt...")
                 while (true) {
-                    val str = predict_loop() ?: break
-                    if (str.isNotEmpty()) {
-                        emit(str)
-                    }
+                    predictLoop()?.let { utf8token ->
+                        if (utf8token.isNotEmpty()) emit(utf8token)
+                    } ?: break
                 }
+
+                Log.i(TAG, "Assistant generation complete!")
+                threadLocalState.set(State.AwaitingUserPrompt)
+            }
+            else -> {
+                Log.w(TAG, "User prompt discarded due to incorrect state: $state")
             }
-            else -> {}
         }
     }.flowOn(runLoop)
 
@@ -146,9 +146,9 @@ class LLamaAndroid {
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
-                is State.ModelLoaded -> {
+                is State.EnvReady -> {
                     Log.d(TAG, "bench(): $state")
-                    bench_model(pp, tg, pl, nr)
+                    benchModel(pp, tg, pl, nr)
                 }
 
                 // TODO-hyin: catch exception in ViewController; disable button when state incorrect
@@ -164,12 +164,14 @@ class LLamaAndroid {
      */
     suspend fun unload() {
         withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
-                    clean_up()
-                    threadLocalState.set(State.Idle)
+            when (val state = threadLocalState.get()) {
+                is State.EnvReady, State.AwaitingUserPrompt -> {
+                    cleanUp()
+                    threadLocalState.set(State.NotInitialized)
+                }
+                else -> {
+                    Log.w(TAG, "Cannot unload model due to incorrect state: $state")
                 }
-                else -> {}
             }
         }
     }
@@ -177,7 +179,10 @@ class LLamaAndroid {
     companion object {
         private val TAG = this::class.simpleName
 
-        private const val DEFAULT_PREDICT_LENGTH = 128
+        private const val LIB_LLAMA_ANDROID = "llama-android"
+        private const val LLAMA_THREAD = "llama-thread"
+
+        private const val DEFAULT_PREDICT_LENGTH = 64
 
         // Enforce only one instance of Llm.
         private val _instance: LLamaAndroid = LLamaAndroid()
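
Illustrative usage sketch of the revised Kotlin API after this change (not part of the patch). It assumes a public singleton accessor such as instance() -- the excerpt above only shows the private _instance -- and the model path, chat-template strings, and the use of runBlocking are placeholders for a real Android caller.

import android.llama.cpp.LLamaAndroid
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val llama = LLamaAndroid.instance() // assumed accessor; only the private _instance is shown above

    // NotInitialized -> EnvReady -> AwaitingUserPrompt (the system prompt is optional)
    llama.load(
        pathToModel = "/data/local/tmp/model.gguf",           // placeholder path
        formattedSystemPrompt = "<|system|>You are concise."  // placeholder template
    )

    // AwaitingUserPrompt -> Processing, then back to AwaitingUserPrompt when generation ends;
    // generated pieces stream out through the Flow<String>.
    llama.sendUserPrompt(formattedMessage = "<|user|>Hello!", nPredict = 64)
        .collect { piece -> print(piece) }

    llama.unload() // EnvReady/AwaitingUserPrompt -> NotInitialized
}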