LLama: add UnloadingModel state to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages

Han Yin 2025-04-18 12:10:09 -07:00
parent 481ba6e9d3
commit c08d02d233
3 changed files with 162 additions and 107 deletions
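The "fix error messages" part amounts to interpolating the state's simple class name instead of the state object itself. A minimal standalone sketch of the difference (illustrative only, not part of the commit):

// Illustrative only: why the error messages switched to javaClass.simpleName.
// State objects don't override toString(), so interpolating the state itself
// prints an opaque identity string.
sealed class State {
    object ModelReady : State()
}

fun main() {
    val state: State = State.ModelReady
    println("Cannot load model in $state")                        // e.g. State$ModelReady@7a4f0f29 (opaque)
    println("Cannot load model in ${state.javaClass.simpleName}") // ModelReady (readable)
}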


@@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
 import android.llama.cpp.InferenceEngine
 import android.llama.cpp.InferenceEngine.State
 import android.util.Log
+import com.example.llama.revamp.APP_NAME
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import org.jetbrains.annotations.TestOnly
 import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
@@ -24,38 +30,54 @@ class StubInferenceEngine : InferenceEngine {
     companion object {
         private val TAG = StubInferenceEngine::class.java.simpleName
-        private const val STUB_MODEL_LOADING_TIME = 2000L
-        private const val STUB_BENCHMARKING_TIME = 4000L
-        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
-        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_LIBRARY_LOADING_TIME = 2_000L
+        private const val STUB_MODEL_LOADING_TIME = 3_000L
+        private const val STUB_MODEL_UNLOADING_TIME = 2_000L
+        private const val STUB_BENCHMARKING_TIME = 8_000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
         private const val STUB_TOKEN_GENERATION_TIME = 200L
     }

     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

+    private var _readyForSystemPrompt = false
+
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
     init {
+        llamaScope.launch {
             Log.i(TAG, "Initiated!")

             // Simulate library loading
+            delay(STUB_LIBRARY_LOADING_TIME)
             _state.value = State.LibraryLoaded
         }
+    }

     /**
      * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String) {
-        Log.i(TAG, "loadModel! state: ${_state.value}")
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }

             try {
+                _readyForSystemPrompt = false
                 _state.value = State.LoadingModel

                 // Simulate model loading
                 delay(STUB_MODEL_LOADING_TIME)
+                _readyForSystemPrompt = true
                 _state.value = State.ModelReady
             } catch (e: CancellationException) {
                 // If coroutine is cancelled, propagate cancellation
                 throw e
@@ -67,7 +89,15 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Process the plain text system prompt
      */
-    override suspend fun setSystemPrompt(prompt: String) {
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
+            check(_readyForSystemPrompt) {
+                "System prompt must be set ** RIGHT AFTER ** model loaded!"
+            }

             try {
                 _state.value = State.ProcessingSystemPrompt
@@ -86,15 +116,17 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
-        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.ModelReady) {
+            "Cannot load model at ${_state.value.javaClass.simpleName}"
+        }

+        try {
+            Log.i(TAG, "sendUserPrompt! \n$message")
             _state.value = State.ProcessingUserPrompt

-        // This would be replaced with actual token generation logic
-        return flow {
-            try {
-                // Simulate longer processing time (1.5 seconds)
+            // Simulate longer processing time
             delay(STUB_USER_PROMPT_PROCESSING_TIME)
             _state.value = State.Generating
@@ -122,22 +154,25 @@ class StubInferenceEngine : InferenceEngine {
                 }
                 throw e
             }
-        }

     /**
      * Runs a benchmark with the specified parameters.
      */
-    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
-        Log.i(TAG, "bench! state: ${_state.value}")
-        _state.value = State.Benchmarking
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }

             try {
+                Log.i(TAG, "bench! state: ${_state.value}")
+                _state.value = State.Benchmarking
+
                 // Simulate benchmark running
                 delay(STUB_BENCHMARKING_TIME)

                 // Generate fake benchmark results
-        val modelDesc = "Kleidi Llama"
+                val modelDesc = APP_NAME
                 val model_size = "7"
                 val model_n_params = "7"
                 val backend = "CPU"
@@ -158,27 +193,38 @@ class StubInferenceEngine : InferenceEngine {
                 _state.value = State.ModelReady
-        return result.toString()
+                result.toString()
             } catch (e: CancellationException) {
                 // If coroutine is cancelled, propagate cancellation
+                Log.w(TAG, "Unexpected user cancellation while benchmarking!")
                 _state.value = State.ModelReady
                 throw e
             } catch (e: Exception) {
                 _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
-            return "Error: ${e.message}"
+                "Error: ${e.message}"
             }
         }

     /**
      * Unloads the currently loaded model.
      */
-    override suspend fun unloadModel() {
-        Log.i(TAG, "unloadModel! state: ${_state.value}")
+    override suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(val state = _state.value) {
+                is State.ModelReady, is State.Error -> {
+                    Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
+                    _state.value = State.UnloadingModel

                     // Simulate model unloading time
-        delay(2000)
+                    delay(STUB_MODEL_UNLOADING_TIME)
                     _state.value = State.LibraryLoaded
                 }
+                else -> throw IllegalStateException(
+                    "Cannot load model at ${_state.value.javaClass.simpleName}"
+                )
+            }
+        }

     /**
      * Cleans up resources when the engine is no longer needed.
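The stub now mirrors the real engine's threading model: all engine work is funneled through a single-parallelism dispatcher so state checks and transitions cannot interleave across concurrent callers. A minimal sketch of that pattern in isolation (class and function names here are illustrative, not part of the commit):

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.withContext

// Illustrative only: limitedParallelism(1) admits at most one coroutine at a time,
// so check()/state updates behave like a critical section without explicit locks.
@OptIn(ExperimentalCoroutinesApi::class)
class SerializedEngineScope {
    private val dispatcher = Dispatchers.IO.limitedParallelism(1)

    // SupervisorJob: a failure in one launched task does not cancel the whole scope.
    val scope = CoroutineScope(dispatcher + SupervisorJob())

    suspend fun <T> runExclusive(block: suspend () -> T): T =
        withContext(dispatcher) { block() }   // calls queue up instead of overlapping
}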


@@ -50,6 +50,7 @@ interface InferenceEngine {
         object LibraryLoaded : State()
         object LoadingModel : State()
+        object UnloadingModel : State()
         object ModelReady : State()
         object ProcessingSystemPrompt : State()
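With UnloadingModel in the sealed State hierarchy, collectors can distinguish an in-flight unload from a completed one. A hedged sketch of how a caller might react to the new state (the function and comments are illustrative, not part of the commit):

import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch

// Illustrative collector: UnloadingModel lets the UI block new requests while
// resources are being released, before LibraryLoaded re-enables model loading.
fun observeEngineState(engine: InferenceEngine, scope: CoroutineScope) {
    scope.launch {
        engine.state.collect { state ->
            when (state) {
                is State.UnloadingModel -> { /* show "Unloading model…" and disable actions */ }
                is State.LibraryLoaded  -> { /* model fully released; allow loadModel() again */ }
                is State.Error          -> { /* surface the message; unloadModel() is still legal here */ }
                else                    -> { /* other states are unchanged by this commit */ }
            }
        }
    }
}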


@@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
-            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }

             File(pathToModel).let {
                 require(it.exists()) { "Model file not found: $pathToModel" }
                 require(it.isFile) { "Model file is not a file: $pathToModel" }
@@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
         withContext(llamaDispatcher) {
             require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
             check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
-            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+            check(_state.value is State.ModelReady) {
+                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+            }

             Log.i(TAG, "Sending system prompt...")
             _readyForSystemPrompt = false
@@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.ModelReady) {
-            "User prompt discarded due to: ${_state.value}"
+            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
         }

         try {
             Log.i(TAG, "Sending user prompt...")
             _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
                     Log.e(TAG, "Failed to process user prompt: $result")
@@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
-            when(_state.value) {
+            when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
+                    _state.value = State.UnloadingModel
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
                     Unit
                 }
-                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
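Taken together, the commit makes the load/unload round trip explicit in the state machine for both engines. A minimal usage sketch against either engine (the helper name and assertions are illustrative, not part of the commit):

import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State

// Illustrative round trip: LibraryLoaded -> LoadingModel -> ModelReady
//                          -> UnloadingModel -> LibraryLoaded.
suspend fun loadThenUnload(engine: InferenceEngine, pathToModel: String) {
    check(engine.state.value is State.LibraryLoaded) { "loadModel() requires LibraryLoaded" }

    engine.loadModel(pathToModel)          // passes through LoadingModel, lands in ModelReady
    check(engine.state.value is State.ModelReady)

    engine.unloadModel()                   // passes through the new UnloadingModel state
    check(engine.state.value is State.LibraryLoaded)

    // Calling unloadModel() again here throws IllegalStateException,
    // since only ModelReady or Error may transition to UnloadingModel.
}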