LLama: add UnloadingModel state to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages

Han Yin 2025-04-18 12:10:09 -07:00
parent 481ba6e9d3
commit c08d02d233
3 changed files with 162 additions and 107 deletions
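For orientation, the sketch below (not part of the commit) shows how a caller might drive an InferenceEngine through the revised lifecycle and where the new UnloadingModel state becomes observable. The model path, the runBlocking wrapper, and the use of first() to await states are illustrative assumptions.

    import android.llama.cpp.InferenceEngine
    import android.llama.cpp.InferenceEngine.State
    import kotlinx.coroutines.flow.first
    import kotlinx.coroutines.runBlocking

    fun demo(engine: InferenceEngine) = runBlocking {
        // The stub engine now "loads its library" asynchronously in init,
        // so wait for LibraryLoaded before calling loadModel().
        engine.state.first { it is State.LibraryLoaded }

        engine.loadModel("/data/local/tmp/model.gguf")   // hypothetical path
        engine.setSystemPrompt("You are a helpful assistant.")

        engine.sendUserPrompt("Hello!", predictLength = 64).collect { token ->
            print(token)
        }

        // unloadModel() now passes through UnloadingModel before returning
        // to LibraryLoaded, so a collector can surface a "tearing down" status.
        engine.unloadModel()
    }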


@@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
 import android.llama.cpp.InferenceEngine
 import android.llama.cpp.InferenceEngine.State
 import android.util.Log
+import com.example.llama.revamp.APP_NAME
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import org.jetbrains.annotations.TestOnly
 import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
@@ -24,161 +30,201 @@ class StubInferenceEngine : InferenceEngine {
     companion object {
         private val TAG = StubInferenceEngine::class.java.simpleName
-        private const val STUB_MODEL_LOADING_TIME = 2000L
-        private const val STUB_BENCHMARKING_TIME = 4000L
-        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
-        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_LIBRARY_LOADING_TIME = 2_000L
+        private const val STUB_MODEL_LOADING_TIME = 3_000L
+        private const val STUB_MODEL_UNLOADING_TIME = 2_000L
+        private const val STUB_BENCHMARKING_TIME = 8_000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
         private const val STUB_TOKEN_GENERATION_TIME = 200L
     }
 
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state
 
-    init {
-        Log.i(TAG, "Initiated!")
-        // Simulate library loading
-        _state.value = State.LibraryLoaded
+    private var _readyForSystemPrompt = false
+
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+    init {
+        llamaScope.launch {
+            Log.i(TAG, "Initiated!")
+            // Simulate library loading
+            delay(STUB_LIBRARY_LOADING_TIME)
+            _state.value = State.LibraryLoaded
+        }
     }
 
     /**
      * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String) {
-        Log.i(TAG, "loadModel! state: ${_state.value}")
-        try {
-            _state.value = State.LoadingModel
-            // Simulate model loading
-            delay(STUB_MODEL_LOADING_TIME)
-            _state.value = State.ModelReady
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
-        }
-    }
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
+            try {
+                _readyForSystemPrompt = false
+                _state.value = State.LoadingModel
+                // Simulate model loading
+                delay(STUB_MODEL_LOADING_TIME)
+                _readyForSystemPrompt = true
+                _state.value = State.ModelReady
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            }
+        }
 
     /**
      * Process the plain text system prompt
      */
-    override suspend fun setSystemPrompt(prompt: String) {
-        try {
-            _state.value = State.ProcessingSystemPrompt
-            // Simulate processing system prompt
-            delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
-            _state.value = State.ModelReady
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
-        }
-    }
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
+            check(_readyForSystemPrompt) {
+                "System prompt must be set ** RIGHT AFTER ** model loaded!"
+            }
+            try {
+                _state.value = State.ProcessingSystemPrompt
+                // Simulate processing system prompt
+                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+                _state.value = State.ModelReady
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            }
+        }
 
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
-        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
-
-        _state.value = State.ProcessingUserPrompt
-
-        // This would be replaced with actual token generation logic
-        return flow {
-            try {
-                // Simulate longer processing time (1.5 seconds)
-                delay(STUB_USER_PROMPT_PROCESSING_TIME)
-
-                _state.value = State.Generating
-
-                // Simulate token generation
-                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                response.split(" ").forEach {
-                    emit("$it ")
-                    delay(STUB_TOKEN_GENERATION_TIME)
-                }
-                _state.value = State.ModelReady
-            } catch (e: CancellationException) {
-                // Handle cancellation gracefully
-                _state.value = State.ModelReady
-                throw e
-            } catch (e: Exception) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
-                throw e
-            }
-        }.catch { e ->
-            // If it's not a cancellation, update state to error
-            if (e !is CancellationException) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
-            }
-            throw e
-        }
-    }
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.ModelReady) {
+            "Cannot load model at ${_state.value.javaClass.simpleName}"
+        }
+
+        try {
+            Log.i(TAG, "sendUserPrompt! \n$message")
+            _state.value = State.ProcessingUserPrompt
+
+            // Simulate longer processing time
+            delay(STUB_USER_PROMPT_PROCESSING_TIME)
+
+            _state.value = State.Generating
+
+            // Simulate token generation
+            val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+            response.split(" ").forEach {
+                emit("$it ")
+                delay(STUB_TOKEN_GENERATION_TIME)
+            }
+            _state.value = State.ModelReady
+        } catch (e: CancellationException) {
+            // Handle cancellation gracefully
+            _state.value = State.ModelReady
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
+            throw e
+        }
+    }.catch { e ->
+        // If it's not a cancellation, update state to error
+        if (e !is CancellationException) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
+        }
+        throw e
+    }
 
     /**
      * Runs a benchmark with the specified parameters.
      */
-    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
-        Log.i(TAG, "bench! state: ${_state.value}")
-
-        _state.value = State.Benchmarking
-
-        try {
-            // Simulate benchmark running
-            delay(STUB_BENCHMARKING_TIME)
-
-            // Generate fake benchmark results
-            val modelDesc = "Kleidi Llama"
-            val model_size = "7"
-            val model_n_params = "7"
-            val backend = "CPU"
-
-            // Random values for benchmarks
-            val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
-            val pp_std = (5.14 + Math.random() * 0.514).toFloat()
-            val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
-            val tg_std = (1.14 + Math.random() * 0.114).toFloat()
-
-            val result = StringBuilder()
-            result.append("| model | size | params | backend | test | t/s |\n")
-            result.append("| --- | --- | --- | --- | --- | --- |\n")
-            result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
-            result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
-            result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
-            result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
-
-            _state.value = State.ModelReady
-            return result.toString()
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            _state.value = State.ModelReady
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
-            return "Error: ${e.message}"
-        }
-    }
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
+            try {
+                Log.i(TAG, "bench! state: ${_state.value}")
+                _state.value = State.Benchmarking
+
+                // Simulate benchmark running
+                delay(STUB_BENCHMARKING_TIME)
+
+                // Generate fake benchmark results
+                val modelDesc = APP_NAME
+                val model_size = "7"
+                val model_n_params = "7"
+                val backend = "CPU"
+
+                // Random values for benchmarks
+                val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
+                val pp_std = (5.14 + Math.random() * 0.514).toFloat()
+                val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
+                val tg_std = (1.14 + Math.random() * 0.114).toFloat()
+
+                val result = StringBuilder()
+                result.append("| model | size | params | backend | test | t/s |\n")
+                result.append("| --- | --- | --- | --- | --- | --- |\n")
+                result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
+                result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
+                result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
+                result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
+
+                _state.value = State.ModelReady
+                result.toString()
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                Log.w(TAG, "Unexpected user cancellation while benchmarking!")
+                _state.value = State.ModelReady
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
+                "Error: ${e.message}"
+            }
+        }
 
     /**
      * Unloads the currently loaded model.
     */
-    override suspend fun unloadModel() {
-        Log.i(TAG, "unloadModel! state: ${_state.value}")
-
-        // Simulate model unloading time
-        delay(2000)
-        _state.value = State.LibraryLoaded
-    }
+    override suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(val state = _state.value) {
+                is State.ModelReady, is State.Error -> {
+                    Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
+                    _state.value = State.UnloadingModel
+
+                    // Simulate model unloading time
+                    delay(STUB_MODEL_UNLOADING_TIME)
+
+                    _state.value = State.LibraryLoaded
+                }
+                else -> throw IllegalStateException(
+                    "Cannot load model at ${_state.value.javaClass.simpleName}"
+                )
+            }
+        }
 
     /**
      * Cleans up resources when the engine is no longer needed.
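With these guards in place, misuse of the stub now fails fast with the same kind of IllegalStateException the real engine throws, which is what the commit title means by "missing state checks". A rough unit-test-style sketch, assuming the stub's package and constructing it directly with a plain runBlocking harness (no virtual-time scheduler shown):

    import com.example.llama.revamp.engine.StubInferenceEngine
    import kotlinx.coroutines.runBlocking

    fun main() = runBlocking {
        val engine = StubInferenceEngine()

        // Immediately after construction the state is still Uninitialized
        // (library "loading" is now delayed), so loadModel() is rejected by
        // the new check(_state.value is State.LibraryLoaded) guard.
        try {
            engine.loadModel("/fake/model.gguf")   // hypothetical path
        } catch (e: IllegalStateException) {
            println(e.message)   // likely "Cannot load model at Uninitialized"
        }

        // Likewise, setSystemPrompt() without a loaded model now fails fast
        // instead of silently flipping the state to ProcessingSystemPrompt.
        try {
            engine.setSystemPrompt("You are a helpful assistant.")
        } catch (e: IllegalStateException) {
            println(e.message)
        }
    }

The next hunk adds the corresponding UnloadingModel object to the InferenceEngine.State hierarchy.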

@@ -50,6 +50,7 @@ interface InferenceEngine {
         object LibraryLoaded : State()
         object LoadingModel : State()
+        object UnloadingModel : State()
         object ModelReady : State()
         object ProcessingSystemPrompt : State()

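On the consumer side, the new state slots into exhaustive handling of InferenceEngine.State. A sketch of a UI-facing mapper that collectors of the engine's state flow might use; only states visible in this commit are matched, and everything else falls through to the else branch:

    import android.llama.cpp.InferenceEngine.State

    fun statusLabel(state: State): String = when (state) {
        is State.LoadingModel   -> "Loading model..."
        is State.UnloadingModel -> "Unloading model..."   // newly representable
        is State.ModelReady     -> "Ready"
        is State.LibraryLoaded  -> "No model loaded"
        is State.Error          -> "Error"
        else                    -> "Working..."
    }

The remaining hunks touch the LLamaAndroid engine: unloadModel() now publishes UnloadingModel before calling unload(), and the IllegalStateException messages report the state's simple class name instead of its default toString().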

@@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
-            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }
 
             File(pathToModel).let {
                 require(it.exists()) { "Model file not found: $pathToModel" }
                 require(it.isFile) { "Model file is not a file: $pathToModel" }
@@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
         withContext(llamaDispatcher) {
             require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
             check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
-            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+            check(_state.value is State.ModelReady) {
+                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+            }
 
             Log.i(TAG, "Sending system prompt...")
             _readyForSystemPrompt = false
@@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.ModelReady) {
-            "User prompt discarded due to: ${_state.value}"
+            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
         }
 
         try {
             Log.i(TAG, "Sending user prompt...")
             _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
                     Log.e(TAG, "Failed to process user prompt: $result")
@@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
-            when(_state.value) {
+            when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
+                    _state.value = State.UnloadingModel
 
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
                     Unit
                 }
-                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
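Why the messages switched from ${_state.value} to ${_state.value.javaClass.simpleName}: the State subtypes are plain Kotlin objects that do not override toString(), so interpolating the value prints the JVM default (a fully qualified name plus a hash), while javaClass.simpleName yields just the state name. A tiny self-contained illustration; the Engine and State names below are stand-ins, not the project's real classes:

    class Engine {
        sealed class State {
            object ModelReady : State()
        }
    }

    fun main() {
        val state: Engine.State = Engine.State.ModelReady
        // Default toString(): something like "Engine$State$ModelReady@<hash>"
        println("Cannot unload model in $state")
        // simpleName: just "ModelReady", which is what the new messages use
        println("Cannot unload model in ${state.javaClass.simpleName}")
    }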