LLama: add ModelUnloadingState to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages

This commit is contained in:
Han Yin 2025-04-18 12:10:09 -07:00
parent 481ba6e9d3
commit c08d02d233
3 changed files with 162 additions and 107 deletions

View File

@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State
import android.util.Log
import com.example.llama.revamp.APP_NAME
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import org.jetbrains.annotations.TestOnly
import org.jetbrains.annotations.VisibleForTesting
import javax.inject.Singleton
@ -24,161 +30,201 @@ class StubInferenceEngine : InferenceEngine {
companion object {
private val TAG = StubInferenceEngine::class.java.simpleName
private const val STUB_MODEL_LOADING_TIME = 2000L
private const val STUB_BENCHMARKING_TIME = 4000L
private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
private const val STUB_LIBRARY_LOADING_TIME = 2_000L
private const val STUB_MODEL_LOADING_TIME = 3_000L
private const val STUB_MODEL_UNLOADING_TIME = 2_000L
private const val STUB_BENCHMARKING_TIME = 8_000L
private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
private const val STUB_TOKEN_GENERATION_TIME = 200L
}
private val _state = MutableStateFlow<State>(State.Uninitialized)
override val state: StateFlow<State> = _state
init {
Log.i(TAG, "Initiated!")
private var _readyForSystemPrompt = false
// Simulate library loading
_state.value = State.LibraryLoaded
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
init {
llamaScope.launch {
Log.i(TAG, "Initiated!")
// Simulate library loading
delay(STUB_LIBRARY_LOADING_TIME)
_state.value = State.LibraryLoaded
}
}
/**
* Loads a model from the given path.
*/
override suspend fun loadModel(pathToModel: String) {
Log.i(TAG, "loadModel! state: ${_state.value}")
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
check(_state.value is State.LibraryLoaded) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
try {
_state.value = State.LoadingModel
try {
_readyForSystemPrompt = false
_state.value = State.LoadingModel
// Simulate model loading
delay(STUB_MODEL_LOADING_TIME)
// Simulate model loading
delay(STUB_MODEL_LOADING_TIME)
_state.value = State.ModelReady
_readyForSystemPrompt = true
_state.value = State.ModelReady
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
}
}
}
/**
* Process the plain text system prompt
*/
override suspend fun setSystemPrompt(prompt: String) {
try {
_state.value = State.ProcessingSystemPrompt
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
check(_readyForSystemPrompt) {
"System prompt must be set ** RIGHT AFTER ** model loaded!"
}
// Simulate processing system prompt
delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
try {
_state.value = State.ProcessingSystemPrompt
_state.value = State.ModelReady
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
// Simulate processing system prompt
delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
_state.value = State.ModelReady
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
}
}
}
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
_state.value = State.ProcessingUserPrompt
try {
Log.i(TAG, "sendUserPrompt! \n$message")
_state.value = State.ProcessingUserPrompt
// This would be replaced with actual token generation logic
return flow {
try {
// Simulate longer processing time (1.5 seconds)
delay(STUB_USER_PROMPT_PROCESSING_TIME)
// Simulate longer processing time
delay(STUB_USER_PROMPT_PROCESSING_TIME)
_state.value = State.Generating
_state.value = State.Generating
// Simulate token generation
val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
response.split(" ").forEach {
emit("$it ")
delay(STUB_TOKEN_GENERATION_TIME)
}
_state.value = State.ModelReady
} catch (e: CancellationException) {
// Handle cancellation gracefully
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
throw e
}
}.catch { e ->
// If it's not a cancellation, update state to error
if (e !is CancellationException) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
// Simulate token generation
val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
response.split(" ").forEach {
emit("$it ")
delay(STUB_TOKEN_GENERATION_TIME)
}
_state.value = State.ModelReady
} catch (e: CancellationException) {
// Handle cancellation gracefully
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
throw e
}
}.catch { e ->
// If it's not a cancellation, update state to error
if (e !is CancellationException) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
}
throw e
}
/**
* Runs a benchmark with the specified parameters.
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
Log.i(TAG, "bench! state: ${_state.value}")
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is State.ModelReady) {
"Cannot load model at ${_state.value.javaClass.simpleName}"
}
_state.value = State.Benchmarking
try {
Log.i(TAG, "bench! state: ${_state.value}")
_state.value = State.Benchmarking
try {
// Simulate benchmark running
delay(STUB_BENCHMARKING_TIME)
// Simulate benchmark running
delay(STUB_BENCHMARKING_TIME)
// Generate fake benchmark results
val modelDesc = "Kleidi Llama"
val model_size = "7"
val model_n_params = "7"
val backend = "CPU"
// Generate fake benchmark results
val modelDesc = APP_NAME
val model_size = "7"
val model_n_params = "7"
val backend = "CPU"
// Random values for benchmarks
val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
val pp_std = (5.14 + Math.random() * 0.514).toFloat()
val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
val tg_std = (1.14 + Math.random() * 0.114).toFloat()
// Random values for benchmarks
val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
val pp_std = (5.14 + Math.random() * 0.514).toFloat()
val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
val tg_std = (1.14 + Math.random() * 0.114).toFloat()
val result = StringBuilder()
result.append("| model | size | params | backend | test | t/s |\n")
result.append("| --- | --- | --- | --- | --- | --- |\n")
result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
val result = StringBuilder()
result.append("| model | size | params | backend | test | t/s |\n")
result.append("| --- | --- | --- | --- | --- | --- |\n")
result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
_state.value = State.ModelReady
_state.value = State.ModelReady
return result.toString()
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during benchmarking")
return "Error: ${e.message}"
result.toString()
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
Log.w(TAG, "Unexpected user cancellation while benchmarking!")
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during benchmarking")
"Error: ${e.message}"
}
}
}
/**
* Unloads the currently loaded model.
*/
override suspend fun unloadModel() {
Log.i(TAG, "unloadModel! state: ${_state.value}")
override suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(val state = _state.value) {
is State.ModelReady, is State.Error -> {
Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
_state.value = State.UnloadingModel
// Simulate model unloading time
delay(2000)
_state.value = State.LibraryLoaded
}
// Simulate model unloading time
delay(STUB_MODEL_UNLOADING_TIME)
_state.value = State.LibraryLoaded
}
else -> throw IllegalStateException(
"Cannot load model at ${_state.value.javaClass.simpleName}"
)
}
}
/**
* Cleans up resources when the engine is no longer needed.

View File

@ -50,6 +50,7 @@ interface InferenceEngine {
object LibraryLoaded : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object ProcessingSystemPrompt : State()

View File

@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
*/
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
check(_state.value is State.LibraryLoaded) {
"Cannot load model in ${_state.value.javaClass.simpleName}!"
}
File(pathToModel).let {
require(it.exists()) { "Model file not found: $pathToModel" }
require(it.isFile) { "Model file is not a file: $pathToModel" }
@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
check(_state.value is State.ModelReady) {
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
}
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.ModelReady) {
"User prompt discarded due to: ${_state.value}"
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
*/
override suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(_state.value) {
when(val state = _state.value) {
is State.ModelReady, is State.Error -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
_state.value = State.UnloadingModel
unload()
_state.value = State.LibraryLoaded
Log.i(TAG, "Model unloaded!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}