diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
index 6517b3f022..df34127212 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
@@ -49,12 +49,14 @@ class StubInferenceEngine : InferenceEngine {
 
     init {
         llamaScope.launch {
-            Log.i(TAG, "Initiated!")
+            Log.i(TAG, "Loading and initializing native library!")
+            _state.value = State.Initializing
 
             // Simulate library loading
             delay(STUB_LIBRARY_LOADING_TIME)
-            _state.value = State.LibraryLoaded
+            Log.i(TAG, "Native library initialized!")
+            _state.value = State.Initialized
         }
     }
 
@@ -64,7 +66,7 @@ class StubInferenceEngine : InferenceEngine {
     override suspend fun loadModel(pathToModel: String) = withContext(llamaDispatcher) {
         Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
 
-        check(_state.value is State.LibraryLoaded) {
+        check(_state.value is State.Initialized) {
             "Cannot load model at ${_state.value.javaClass.simpleName}"
         }
 
@@ -218,7 +220,7 @@ class StubInferenceEngine : InferenceEngine {
 
                 // Simulate model unloading time
                 delay(STUB_MODEL_UNLOADING_TIME)
-                _state.value = State.LibraryLoaded
+                _state.value = State.Initialized
             }
             else -> throw IllegalStateException(
                 "Cannot load model at ${_state.value.javaClass.simpleName}"
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
index b1f1733ea4..b39a8bcfc4 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
@@ -1,5 +1,6 @@
 package android.llama.cpp
 
+import android.llama.cpp.InferenceEngine.State
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.StateFlow
 
@@ -47,17 +48,18 @@ interface InferenceEngine {
      */
     sealed class State {
         object Uninitialized : State()
-        object LibraryLoaded : State()
+        object Initializing : State()
+        object Initialized : State()
 
         object LoadingModel : State()
        object UnloadingModel : State()
         object ModelReady : State()
+        object Benchmarking : State()
 
         object ProcessingSystemPrompt : State()
         object ProcessingUserPrompt : State()
-        object Generating : State()
-        object Benchmarking : State()
+        object Generating : State()
 
         data class Error(val errorMessage: String = "") : State()
     }
@@ -66,3 +68,16 @@ interface InferenceEngine {
         const val DEFAULT_PREDICT_LENGTH = 1024
     }
 }
+
+val State.isUninterruptible
+    get() = this !is State.Initialized &&
+        this !is State.ModelReady &&
+        this !is State.Generating &&
+        this !is State.Error
+
+val State.isModelLoaded: Boolean
+    get() = this !is State.Uninitialized &&
+        this !is State.Initializing &&
+        this !is State.Initialized &&
+        this !is State.LoadingModel &&
+        this !is State.UnloadingModel
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index cbf8c4823a..6088ba5921 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -71,13 +71,18 @@ class LLamaAndroid private constructor() : InferenceEngine {
     init {
         llamaScope.launch {
             try {
+                check(_state.value is State.Uninitialized) {
+                    "Cannot load native library in ${_state.value.javaClass.simpleName}!"
+                }
+
+                _state.value = State.Initializing
                 System.loadLibrary(LIB_LLAMA_ANDROID)
                 init()
-                _state.value = State.LibraryLoaded
+                _state.value = State.Initialized
                 Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
             } catch (e: Exception) {
-                _state.value = State.Error("Failed to load native library: ${e.message}")
                 Log.e(TAG, "Failed to load native library", e)
+                throw e
             }
         }
     }
@@ -87,7 +92,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun loadModel(pathToModel: String) = withContext(llamaDispatcher) {
-        check(_state.value is State.LibraryLoaded) {
+        check(_state.value is State.Initialized) {
             "Cannot load model in ${_state.value.javaClass.simpleName}!"
         }
 
         File(pathToModel).let {
@@ -207,7 +212,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
 
             unload()
 
-            _state.value = State.LibraryLoaded
+            _state.value = State.Initialized
             Log.i(TAG, "Model unloaded!")
             Unit
         }
@@ -223,7 +228,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
         llamaScope.cancel()
         when(_state.value) {
             is State.Uninitialized -> {}
-            is State.LibraryLoaded -> shutdown()
+            is State.Initialized -> shutdown()
             else -> { unload(); shutdown() }
         }
     }
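A minimal sketch of how downstream code might consume the new isModelLoaded and
isUninterruptible extension properties to gate a UI action; the canUnloadModel
function below is hypothetical and not part of this patch:

    // Hypothetical consumer; only the State hierarchy and the two
    // extension properties come from this patch.
    import android.llama.cpp.InferenceEngine.State
    import android.llama.cpp.isModelLoaded
    import android.llama.cpp.isUninterruptible

    // Offer an "Unload model" action only when a model is resident and
    // the engine is in a state that can safely be interrupted.
    fun canUnloadModel(state: State): Boolean =
        state.isModelLoaded && !state.isUninterruptible

    // e.g. canUnloadModel(State.ModelReady) == true,
    //      canUnloadModel(State.LoadingModel) == false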