From 7dc9968f82682c9276b76283cf65f5953a7b6049 Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Fri, 28 Mar 2025 12:21:14 -0700
Subject: [PATCH] Restructure `LLamaAndroid.kt`

---
 .../java/com/example/llama/MainViewModel.kt |   3 +-
 .../java/android/llama/cpp/LLamaAndroid.kt  | 176 ++++++++++--------
 2 files changed, 100 insertions(+), 79 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
index c7ffa3f77c..9b1aa9d96c 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
@@ -45,7 +45,8 @@ class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instan
         messages += ""
 
         viewModelScope.launch {
-            llamaAndroid.sendMessage(text)
+            // TODO-hyin: implement format message
+            llamaAndroid.sendUserPrompt(formattedMessage = text)
                 .catch {
                     Log.e(tag, "send() failed", it)
                     messages += it.message!!
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index dfba00475b..a8102eea04 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -11,31 +11,10 @@ import java.util.concurrent.Executors
 import kotlin.concurrent.thread
 
 class LLamaAndroid {
-    private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
-    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
-        thread(start = false, name = "Llm-RunLoop") {
-            Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
-            // No-op if called more than once.
-            System.loadLibrary("llama-android")
-
-            // Set llama log handler to Android
-            log_to_android()
-            backend_init()
-
-            Log.d(tag, system_info())
-
-            it.run()
-        }.apply {
-            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
-                Log.e(tag, "Unhandled exception", exception)
-            }
-        }
-    }.asCoroutineDispatcher()
-
+    /**
+     * JNI methods
+     * @see llama-android.cpp
+     */
     private external fun log_to_android()
     private external fun system_info(): String
     private external fun backend_init()
@@ -50,6 +29,40 @@ class LLamaAndroid {
     private external fun process_user_prompt(user_prompt: String, nLen: Int): Int
     private external fun predict_loop(): String?
 
+    /**
+     * Thread local state
+     */
+    private sealed interface State {
+        data object Idle: State
+        data object ModelLoaded: State
+        data object ReadyForUserPrompt: State
+    }
+    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
+
+    private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
+        thread(start = false, name = "Llm-RunLoop") {
+            Log.d(TAG, "Dedicated thread for native code: ${Thread.currentThread().name}")
+
+            // No-op if called more than once.
+            System.loadLibrary("llama-android")
+
+            // Set llama log handler to Android
+            log_to_android()
+            backend_init()
+
+            Log.d(TAG, system_info())
+
+            it.run()
+        }.apply {
+            uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
+                Log.e(TAG, "Unhandled exception", exception)
+            }
+        }
+    }.asCoroutineDispatcher()
+
+    /**
+     * Load the LLM, then process the system prompt if provided
+     */
     suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
         withContext(runLoop) {
             when (threadLocalState.get()) {
@@ -60,12 +73,13 @@ class LLamaAndroid {
                     val result = ctx_init()
                     if (result != 0) throw IllegalStateException("Initialization failed with error code: $result")
 
-                    Log.i(tag, "Loaded model $pathToModel")
+                    Log.i(TAG, "Loaded model $pathToModel")
                     threadLocalState.set(State.ModelLoaded)
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
                     } ?: {
+                        Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.ReadyForUserPrompt)
                     }
                 }
@@ -74,11 +88,66 @@ class LLamaAndroid {
         }
     }
 
+    /**
+     * Helper method to process system prompt and update [State]
+     */
+    private suspend fun initWithSystemPrompt(formattedMessage: String) {
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.ModelLoaded -> {
+                    Log.i(TAG, "Process system prompt...")
+                    process_system_prompt(formattedMessage).let {
+                        if (it != 0)
+                            throw IllegalStateException("Failed to process system prompt: $it")
+                    }
+
+                    Log.i(TAG, "System prompt processed!")
+                    threadLocalState.set(State.ReadyForUserPrompt)
+                }
+                else -> throw IllegalStateException(
+                    "Failed to process system prompt: Model not loaded!"
+                )
+            }
+        }
+    }
+
+    /**
+     * Send plain text user prompt to LLM
+     */
+    fun sendUserPrompt(
+        formattedMessage: String,
+        nPredict: Int = DEFAULT_PREDICT_LENGTH,
+    ): Flow<String> = flow {
+        when (threadLocalState.get()) {
+            is State.ReadyForUserPrompt -> {
+                process_user_prompt(formattedMessage, nPredict).let {
+                    if (it != 0) {
+                        Log.e(TAG, "Failed to process user prompt: $it")
+                        return@flow
+                    }
+                }
+
+                Log.i(TAG, "User prompt processed! Generating assistant prompt...")
+                while (true) {
+                    val str = predict_loop() ?: break
+                    if (str.isNotEmpty()) {
+                        emit(str)
+                    }
+                }
+                Log.i(TAG, "Assistant generation complete!")
+            }
+            else -> {}
+        }
+    }.flowOn(runLoop)
+
+    /**
+     * Benchmark the model
+     */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
                 is State.ModelLoaded -> {
-                    Log.d(tag, "bench(): $state")
+                    Log.d(TAG, "bench(): $state")
                     bench_model(pp, tg, pl, nr)
                 }
 
@@ -88,50 +157,6 @@ class LLamaAndroid {
         }
     }
 
-    private suspend fun initWithSystemPrompt(systemPrompt: String) {
-        withContext(runLoop) {
-            when (threadLocalState.get()) {
-                is State.ModelLoaded -> {
-                    process_system_prompt(systemPrompt).let {
-                        if (it != 0) {
-                            throw IllegalStateException("Failed to process system prompt: $it")
-                        }
-                    }
-
-                    Log.i(tag, "System prompt processed!")
-                    threadLocalState.set(State.ReadyForUserPrompt)
-                }
-                else -> throw IllegalStateException("Model not loaded")
-            }
-        }
-    }
-
-    fun sendMessage(
-        formattedUserPrompt: String,
-        nPredict: Int = DEFAULT_PREDICT_LENGTH,
-    ): Flow<String> = flow {
-        when (threadLocalState.get()) {
-            is State.ReadyForUserPrompt -> {
-                process_user_prompt(formattedUserPrompt, nPredict).let {
-                    if (it != 0) {
-                        Log.e(tag, "Failed to process user prompt: $it")
-                        return@flow
-                    }
-                }
-
-                Log.i(tag, "User prompt processed! Generating assistant prompt...")
-                while (true) {
-                    val str = predict_loop() ?: break
-                    if (str.isNotEmpty()) {
-                        emit(str)
-                    }
-                }
-                Log.i(tag, "Assistant generation complete!")
-            }
-            else -> {}
-        }
-    }.flowOn(runLoop)
-
     /**
      * Unloads the model and frees resources.
     *
@@ -150,17 +175,12 @@ class LLamaAndroid {
         }
     }
     companion object {
-        private const val DEFAULT_PREDICT_LENGTH = 128
+        private val TAG = this::class.simpleName
 
-        private sealed interface State {
-            data object Idle: State
-            data object ModelLoaded: State
-            data object ReadyForUserPrompt: State
-        }
+        private const val DEFAULT_PREDICT_LENGTH = 128
 
         // Enforce only one instance of Llm.
         private val _instance: LLamaAndroid = LLamaAndroid()
-
         fun instance(): LLamaAndroid = _instance
     }
 }
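Usage note (not part of the patch): with this restructure a caller has to drive the wrapper through load() before sendUserPrompt() emits anything, since sendUserPrompt() falls through its else branch in any state other than ReadyForUserPrompt. The sketch below shows roughly how a caller such as MainViewModel could use the new API; the model path, prompt strings, and function name are illustrative placeholders, not values taken from the patch.

    import android.llama.cpp.LLamaAndroid
    import android.util.Log
    import kotlinx.coroutines.flow.catch
    import kotlinx.coroutines.flow.collect

    // Minimal sketch, assuming the patched LLamaAndroid API shown above.
    suspend fun chatOnce(llama: LLamaAndroid = LLamaAndroid.instance()) {
        // Placeholder path and prompts -- the real app supplies already-formatted text.
        val modelPath = "/data/local/tmp/model.gguf"
        val systemPrompt = "You are a helpful assistant."

        // Idle -> ModelLoaded -> ReadyForUserPrompt (the system prompt is optional).
        llama.load(pathToModel = modelPath, formattedSystemPrompt = systemPrompt)

        // sendUserPrompt() returns a Flow<String> of generated pieces,
        // produced on the dedicated Llm-RunLoop dispatcher via flowOn(runLoop).
        llama.sendUserPrompt(formattedMessage = "Hello!")
            .catch { Log.e("ChatOnce", "generation failed", it) }
            .collect { piece -> print(piece) }

        llama.unload()
    }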