diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index ed46c16713..e84259422f 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -57,7 +57,7 @@ class LLamaAndroid {
     /**
      * Load the LLM, then process the formatted system prompt if provided
      */
-    suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) {
+    suspend fun load(pathToModel: String, formattedSystemPrompt: String? = null) =
         withContext(runLoop) {
             when (threadLocalState.get()) {
                 is State.NotInitialized -> {
@@ -72,7 +72,7 @@ class LLamaAndroid {
 
                     formattedSystemPrompt?.let {
                         initWithSystemPrompt(formattedSystemPrompt)
-                    } ?: {
+                    } ?: run {
                         Log.w(TAG, "No system prompt to process.")
                         threadLocalState.set(State.AwaitingUserPrompt)
                     }
@@ -80,12 +80,11 @@ class LLamaAndroid {
                 else -> throw IllegalStateException("Model already loaded")
             }
         }
-    }
 
     /**
      * Helper method to process system prompt and update [State]
      */
-    private suspend fun initWithSystemPrompt(formattedMessage: String) {
+    private suspend fun initWithSystemPrompt(formattedMessage: String) =
         withContext(runLoop) {
             when (threadLocalState.get()) {
                 is State.EnvReady -> {
@@ -104,7 +103,6 @@ class LLamaAndroid {
                 )
             }
         }
-    }
 
     /**
      * Send formatted user prompt to LLM
@@ -113,6 +111,10 @@ class LLamaAndroid {
         formattedMessage: String,
         nPredict: Int = DEFAULT_PREDICT_LENGTH,
     ): Flow<String> = flow {
+        require(formattedMessage.isNotEmpty()) {
+            Log.w(TAG, "User prompt discarded due to being empty!")
+        }
+
        when (val state = threadLocalState.get()) {
             is State.AwaitingUserPrompt -> {
                 Log.i(TAG, "Sending user prompt...")
@@ -143,26 +145,28 @@ class LLamaAndroid {
 
     /**
      * Benchmark the model
      */
-    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
-        return withContext(runLoop) {
-            when (val state = threadLocalState.get()) {
-                is State.EnvReady -> {
-                    Log.d(TAG, "bench(): $state")
-                    benchModel(pp, tg, pl, nr)
+    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String =
+        withContext(runLoop) {
+            when (threadLocalState.get()) {
+                is State.AwaitingUserPrompt -> {
+                    threadLocalState.set(State.Processing)
+                    Log.d(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+                    benchModel(pp, tg, pl, nr).also {
+                        threadLocalState.set(State.AwaitingUserPrompt)
+                    }
                 }
-                // TODO-hyin: catch exception in ViewController; disable button when state incorrect
+                // TODO-hyin: disable button when state incorrect
                 else -> throw IllegalStateException("No model loaded")
             }
         }
-    }
 
     /**
     * Unloads the model and frees resources.
     *
     * This is a no-op if there's no model loaded.
     */
-    suspend fun unload() {
+    suspend fun unload() =
         withContext(runLoop) {
             when (val state = threadLocalState.get()) {
                 is State.EnvReady, State.AwaitingUserPrompt -> {
@@ -174,10 +178,9 @@ class LLamaAndroid {
                 }
             }
         }
-    }
 
     companion object {
-        private val TAG = this::class.simpleName
+        private val TAG = LLamaAndroid::class.simpleName
        private const val LIB_LLAMA_ANDROID = "llama-android"
         private const val LLAMA_THREAD = "llama-thread"
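
Note on the `?: run { ... }` change in load(): in Kotlin, `expr ?: { ... }` does not execute the block; it evaluates to the lambda object itself, so the original fallback branch never logged the warning or advanced the state machine. A minimal standalone sketch of the pitfall (illustrative code only, not part of the patch):

fun main() {
    val prompt: String? = null

    // Buggy form: `?: { ... }` yields the lambda object; its body never runs.
    val broken = prompt ?: { "fallback" }
    println(broken)   // prints a Function0 reference, not "fallback"

    // Fixed form: `?: run { ... }` executes the block and returns its result.
    val fixed = prompt ?: run {
        println("no prompt supplied, using default")
        "fallback"
    }
    println(fixed)    // prints "fallback"
}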