LLama: add ModelUnloadingState to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages

This commit is contained in:
Han Yin 2025-04-18 12:10:09 -07:00
parent 481ba6e9d3
commit c08d02d233
3 changed files with 162 additions and 107 deletions

View File

@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State
import android.util.Log
import com.example.llama.revamp.APP_NAME
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.jetbrains.annotations.TestOnly
import org.jetbrains.annotations.VisibleForTesting
import javax.inject.Singleton
@ -24,38 +30,54 @@ class StubInferenceEngine : InferenceEngine {
companion object {
private val TAG = StubInferenceEngine::class.java.simpleName
private const val STUB_MODEL_LOADING_TIME = 2000L
private const val STUB_BENCHMARKING_TIME = 4000L
private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
private const val STUB_LIBRARY_LOADING_TIME = 2_000L
private const val STUB_MODEL_LOADING_TIME = 3_000L
private const val STUB_MODEL_UNLOADING_TIME = 2_000L
private const val STUB_BENCHMARKING_TIME = 8_000L
private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
private const val STUB_TOKEN_GENERATION_TIME = 200L
}
private val _state = MutableStateFlow<State>(State.Uninitialized)
override val state: StateFlow<State> = _state
private var _readyForSystemPrompt = false
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
init {
llamaScope.launch {
Log.i(TAG, "Initiated!")
// Simulate library loading
delay(STUB_LIBRARY_LOADING_TIME)
_state.value = State.LibraryLoaded
}
}
/**
* Loads a model from the given path.
*/
override suspend fun loadModel(pathToModel: String) {
Log.i(TAG, "loadModel! state: ${_state.value}")
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
check(_state.value is State.LibraryLoaded) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
try {
_readyForSystemPrompt = false
_state.value = State.LoadingModel
// Simulate model loading
delay(STUB_MODEL_LOADING_TIME)
_readyForSystemPrompt = true
_state.value = State.ModelReady
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
@ -67,7 +89,15 @@ class StubInferenceEngine : InferenceEngine {
/**
* Process the plain text system prompt
*/
override suspend fun setSystemPrompt(prompt: String) {
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
check(_readyForSystemPrompt) {
"System prompt must be set ** RIGHT AFTER ** model loaded!"
}
try {
_state.value = State.ProcessingSystemPrompt
@ -86,15 +116,17 @@ class StubInferenceEngine : InferenceEngine {
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "sendUserPrompt! \n$message")
_state.value = State.ProcessingUserPrompt
// This would be replaced with actual token generation logic
return flow {
try {
// Simulate longer processing time (1.5 seconds)
// Simulate longer processing time
delay(STUB_USER_PROMPT_PROCESSING_TIME)
_state.value = State.Generating
@ -122,22 +154,25 @@ class StubInferenceEngine : InferenceEngine {
}
throw e
}
}
/**
* Runs a benchmark with the specified parameters.
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
Log.i(TAG, "bench! state: ${_state.value}")
_state.value = State.Benchmarking
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "bench! state: ${_state.value}")
_state.value = State.Benchmarking
// Simulate benchmark running
delay(STUB_BENCHMARKING_TIME)
// Generate fake benchmark results
val modelDesc = "Kleidi Llama"
val modelDesc = APP_NAME
val model_size = "7"
val model_n_params = "7"
val backend = "CPU"
@ -158,27 +193,38 @@ class StubInferenceEngine : InferenceEngine {
_state.value = State.ModelReady
return result.toString()
result.toString()
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
Log.w(TAG, "Unexpected user cancellation while benchmarking!")
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during benchmarking")
return "Error: ${e.message}"
"Error: ${e.message}"
}
}
/**
* Unloads the currently loaded model.
*/
override suspend fun unloadModel() {
Log.i(TAG, "unloadModel! state: ${_state.value}")
override suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(val state = _state.value) {
is State.ModelReady, is State.Error -> {
Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
_state.value = State.UnloadingModel
// Simulate model unloading time
delay(2000)
delay(STUB_MODEL_UNLOADING_TIME)
_state.value = State.LibraryLoaded
}
else -> throw IllegalStateException(
"Cannot load model at ${_state.value.javaClass.simpleName}"
)
}
}
/**
* Cleans up resources when the engine is no longer needed.

View File

@ -50,6 +50,7 @@ interface InferenceEngine {
object LibraryLoaded : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object ProcessingSystemPrompt : State()

View File

@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
*/
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
check(_state.value is State.LibraryLoaded) {
"Cannot load model in ${_state.value.javaClass.simpleName}!"
}
File(pathToModel).let {
require(it.exists()) { "Model file not found: $pathToModel" }
require(it.isFile) { "Model file is not a file: $pathToModel" }
@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
check(_state.value is State.ModelReady) {
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
}
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.ModelReady) {
"User prompt discarded due to: ${_state.value}"
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
*/
override suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(_state.value) {
when(val state = _state.value) {
is State.ModelReady, is State.Error -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
_state.value = State.UnloadingModel
unload()
_state.value = State.LibraryLoaded
Log.i(TAG, "Model unloaded!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}