LLama: add UnloadingModel state to engine State; add missing state checks in stub engine; fix inference engine's error messages
This commit is contained in:
parent 481ba6e9d3
commit c08d02d233
@@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
 import android.llama.cpp.InferenceEngine
 import android.llama.cpp.InferenceEngine.State
 import android.util.Log
+import com.example.llama.revamp.APP_NAME
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import org.jetbrains.annotations.TestOnly
 import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
@@ -24,161 +30,201 @@ class StubInferenceEngine : InferenceEngine {
     companion object {
         private val TAG = StubInferenceEngine::class.java.simpleName

-        private const val STUB_MODEL_LOADING_TIME = 2000L
-        private const val STUB_BENCHMARKING_TIME = 4000L
-        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
-        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_LIBRARY_LOADING_TIME = 2_000L
+        private const val STUB_MODEL_LOADING_TIME = 3_000L
+        private const val STUB_MODEL_UNLOADING_TIME = 2_000L
+        private const val STUB_BENCHMARKING_TIME = 8_000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
         private const val STUB_TOKEN_GENERATION_TIME = 200L
     }

     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

-    init {
-        Log.i(TAG, "Initiated!")
-
-        // Simulate library loading
-        _state.value = State.LibraryLoaded
+    private var _readyForSystemPrompt = false
+
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+    init {
+        llamaScope.launch {
+            Log.i(TAG, "Initiated!")
+
+            // Simulate library loading
+            delay(STUB_LIBRARY_LOADING_TIME)
+
+            _state.value = State.LibraryLoaded
+        }
     }
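Note on the pattern above: `Dispatchers.IO.limitedParallelism(1)` gives the stub the same serialized execution model the real engine uses for its `llamaDispatcher`, so state checks and transitions can never interleave. A self-contained illustration of that guarantee (a sketch, not code from this diff):

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

// limitedParallelism(1) carves a serialized lane out of Dispatchers.IO:
// blocks dispatched through it run one at a time, in submission order.
@OptIn(ExperimentalCoroutinesApi::class)
fun main() = runBlocking {
    val serial = Dispatchers.IO.limitedParallelism(1)
    var counter = 0                                       // deliberately unsynchronized
    val jobs = List(1_000) { launch(serial) { counter++ } }
    jobs.forEach { it.join() }
    println(counter)                                      // always 1000: no interleaving
}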

     /**
      * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String) {
-        Log.i(TAG, "loadModel! state: ${_state.value}")
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }

-        try {
-            _state.value = State.LoadingModel
+            try {
+                _readyForSystemPrompt = false
+                _state.value = State.LoadingModel

-            // Simulate model loading
-            delay(STUB_MODEL_LOADING_TIME)
+                // Simulate model loading
+                delay(STUB_MODEL_LOADING_TIME)

-            _state.value = State.ModelReady
+                _readyForSystemPrompt = true
+                _state.value = State.ModelReady

-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            }
         }
-    }

     /**
      * Process the plain text system prompt
      */
-    override suspend fun setSystemPrompt(prompt: String) {
-        try {
-            _state.value = State.ProcessingSystemPrompt
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot set system prompt at ${_state.value.javaClass.simpleName}"
+            }
+            check(_readyForSystemPrompt) {
+                "System prompt must be set ** RIGHT AFTER ** model loaded!"
+            }

-            // Simulate processing system prompt
-            delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+            try {
+                _state.value = State.ProcessingSystemPrompt

-            _state.value = State.ModelReady
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+                // Simulate processing system prompt
+                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+
+                _state.value = State.ModelReady
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during system prompt processing")
+            }
         }
-    }

     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
-        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.ModelReady) {
+            "Cannot send user prompt at ${_state.value.javaClass.simpleName}"
+        }

-        _state.value = State.ProcessingUserPrompt
+        try {
+            Log.i(TAG, "sendUserPrompt! \n$message")
+            _state.value = State.ProcessingUserPrompt

-        // This would be replaced with actual token generation logic
-        return flow {
-            try {
-                // Simulate longer processing time (1.5 seconds)
-                delay(STUB_USER_PROMPT_PROCESSING_TIME)
+            // Simulate longer processing time
+            delay(STUB_USER_PROMPT_PROCESSING_TIME)

-                _state.value = State.Generating
+            _state.value = State.Generating

-                // Simulate token generation
-                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                response.split(" ").forEach {
-                    emit("$it ")
-                    delay(STUB_TOKEN_GENERATION_TIME)
-                }
-
-                _state.value = State.ModelReady
-            } catch (e: CancellationException) {
-                // Handle cancellation gracefully
-                _state.value = State.ModelReady
-                throw e
-            } catch (e: Exception) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
-                throw e
-            }
-        }.catch { e ->
-            // If it's not a cancellation, update state to error
-            if (e !is CancellationException) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
+            // Simulate token generation
+            val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+            response.split(" ").forEach {
+                emit("$it ")
+                delay(STUB_TOKEN_GENERATION_TIME)
+            }
+
+            _state.value = State.ModelReady
+        } catch (e: CancellationException) {
+            // Handle cancellation gracefully
+            _state.value = State.ModelReady
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
+            throw e
+        }
+    }.catch { e ->
+        // If it's not a cancellation, update state to error
+        if (e !is CancellationException) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
         }
         throw e
     }

     /**
      * Runs a benchmark with the specified parameters.
      */
-    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
-        Log.i(TAG, "bench! state: ${_state.value}")
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot run benchmark at ${_state.value.javaClass.simpleName}"
+            }

-        _state.value = State.Benchmarking
-        try {
-            // Simulate benchmark running
-            delay(STUB_BENCHMARKING_TIME)
+            Log.i(TAG, "bench! state: ${_state.value}")
+            _state.value = State.Benchmarking
+
+            try {
+                // Simulate benchmark running
+                delay(STUB_BENCHMARKING_TIME)

-            // Generate fake benchmark results
-            val modelDesc = "Kleidi Llama"
-            val model_size = "7"
-            val model_n_params = "7"
-            val backend = "CPU"
+                // Generate fake benchmark results
+                val modelDesc = APP_NAME
+                val model_size = "7"
+                val model_n_params = "7"
+                val backend = "CPU"

-            // Random values for benchmarks
-            val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
-            val pp_std = (5.14 + Math.random() * 0.514).toFloat()
-            val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
-            val tg_std = (1.14 + Math.random() * 0.114).toFloat()
+                // Random values for benchmarks
+                val pp_avg = (51.4 + Math.random() * 5.14).toFloat()
+                val pp_std = (5.14 + Math.random() * 0.514).toFloat()
+                val tg_avg = (11.4 + Math.random() * 1.14).toFloat()
+                val tg_std = (1.14 + Math.random() * 0.114).toFloat()

-            val result = StringBuilder()
-            result.append("| model | size | params | backend | test | t/s |\n")
-            result.append("| --- | --- | --- | --- | --- | --- |\n")
-            result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
-            result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
-            result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
-            result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
+                val result = StringBuilder()
+                result.append("| model | size | params | backend | test | t/s |\n")
+                result.append("| --- | --- | --- | --- | --- | --- |\n")
+                result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
+                result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n")
+                result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
+                result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")

-            _state.value = State.ModelReady
+                _state.value = State.ModelReady

-            return result.toString()
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            _state.value = State.ModelReady
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
-            return "Error: ${e.message}"
+                result.toString()
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                Log.w(TAG, "Unexpected user cancellation while benchmarking!")
+                _state.value = State.ModelReady
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
+                "Error: ${e.message}"
+            }
         }
-    }

     /**
      * Unloads the currently loaded model.
      */
-    override suspend fun unloadModel() {
-        Log.i(TAG, "unloadModel! state: ${_state.value}")
-
-        // Simulate model unloading time
-        delay(2000)
-        _state.value = State.LibraryLoaded
-    }
+    override suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(val state = _state.value) {
+                is State.ModelReady, is State.Error -> {
+                    Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
+                    _state.value = State.UnloadingModel
+
+                    // Simulate model unloading time
+                    delay(STUB_MODEL_UNLOADING_TIME)
+
+                    _state.value = State.LibraryLoaded
+                }
+                else -> throw IllegalStateException(
+                    "Cannot unload model at ${state.javaClass.simpleName}"
+                )
+            }
+        }

     /**
      * Cleans up resources when the engine is no longer needed.
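Taken together, the stub now serializes every operation and mirrors the real engine's lifecycle and checks. A hypothetical caller driving it end to end (the function name, model path, and prompt strings here are invented, not part of the commit):

import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State
import kotlinx.coroutines.flow.first

suspend fun runStubOnce(engine: InferenceEngine) {
    // The init block now loads the "library" asynchronously; wait for it.
    engine.state.first { it is State.LibraryLoaded }

    engine.loadModel("/sdcard/models/any-path-the-stub-ignores.gguf")

    // Must be the first call after loadModel, or the readiness check throws.
    engine.setSystemPrompt("You are a helpful assistant.")

    // Tokens arrive word by word, STUB_TOKEN_GENERATION_TIME apart.
    engine.sendUserPrompt("Hello!", predictLength = 64).collect { print(it) }

    // New in this commit: ModelReady -> UnloadingModel -> LibraryLoaded.
    engine.unloadModel()
}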
@@ -50,6 +50,7 @@ interface InferenceEngine {
         object LibraryLoaded : State()

         object LoadingModel : State()
+        object UnloadingModel : State()
         object ModelReady : State()

         object ProcessingSystemPrompt : State()
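For orientation, the states referenced elsewhere in this diff suggest the sealed hierarchy now looks roughly like this (a reconstruction from the hunks in this commit, not the file's actual contents):

sealed class State {
    object Uninitialized : State()
    object LibraryLoaded : State()
    object LoadingModel : State()
    object UnloadingModel : State()            // added by this commit
    object ModelReady : State()
    object ProcessingSystemPrompt : State()
    object ProcessingUserPrompt : State()
    object Generating : State()
    object Benchmarking : State()
    data class Error(val message: String) : State()   // likely a data class: it carries e.message
}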
@@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
-            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }
             File(pathToModel).let {
                 require(it.exists()) { "Model file not found: $pathToModel" }
                 require(it.isFile) { "Model file is not a file: $pathToModel" }
@@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
         withContext(llamaDispatcher) {
             require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
             check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
-            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+            check(_state.value is State.ModelReady) {
+                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+            }

             Log.i(TAG, "Sending system prompt...")
             _readyForSystemPrompt = false
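The `_readyForSystemPrompt` checks across these hunks implement a one-shot gate: only the first operation after a successful load may set the system prompt. A condensed sketch of that protocol, reconstructed from the checks and assignments in this diff (the class and method names here are illustrative; the real methods do much more):

class SystemPromptGate {
    private var readyForSystemPrompt = false

    fun onModelLoaded() {
        readyForSystemPrompt = true      // loadModel success: the window opens
    }

    fun onSystemPrompt() {
        check(readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
        readyForSystemPrompt = false     // consumed by the first prompt
    }

    fun onUserPrompt() {
        readyForSystemPrompt = false     // a user prompt closes the window too
    }
}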
@@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.ModelReady) {
-            "User prompt discarded due to: ${_state.value}"
+            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
         }

         try {
             Log.i(TAG, "Sending user prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt

             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
                     Log.e(TAG, "Failed to process user prompt: $result")
@@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
-            when(_state.value) {
+            when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
+                    _readyForSystemPrompt = false
                     _state.value = State.UnloadingModel

                     unload()

                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
+                    Unit
                 }
-                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
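A hypothetical test-style check for the new transition (nothing below is in the commit; `engine` is assumed to already be in ModelReady, and `scope` is whatever test scope is available):

import android.llama.cpp.InferenceEngine
import android.llama.cpp.InferenceEngine.State
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch

suspend fun assertUnloadTransitions(engine: InferenceEngine, scope: CoroutineScope) {
    val seen = mutableListOf<State>()
    val job = scope.launch { engine.state.collect { seen += it } }  // StateFlow never completes

    engine.unloadModel()
    job.cancelAndJoin()

    // StateFlow conflates rapid updates, but both engines hold UnloadingModel
    // across real work (unload() / a simulated delay), so the hop is observable.
    check(State.UnloadingModel in seen) { "expected an UnloadingModel hop, saw: $seen" }
    check(engine.state.value == State.LibraryLoaded)
}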