LLama: add ModelUnloadingState to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages
commit c08d02d233
parent 481ba6e9d3
StubInferenceEngine.kt

@@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
 import android.llama.cpp.InferenceEngine
 import android.llama.cpp.InferenceEngine.State
 import android.util.Log
+import com.example.llama.revamp.APP_NAME
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import org.jetbrains.annotations.TestOnly
 import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
@@ -24,38 +30,54 @@ class StubInferenceEngine : InferenceEngine {
     companion object {
         private val TAG = StubInferenceEngine::class.java.simpleName

-        private const val STUB_MODEL_LOADING_TIME = 2000L
-        private const val STUB_BENCHMARKING_TIME = 4000L
-        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
-        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_LIBRARY_LOADING_TIME = 2_000L
+        private const val STUB_MODEL_LOADING_TIME = 3_000L
+        private const val STUB_MODEL_UNLOADING_TIME = 2_000L
+        private const val STUB_BENCHMARKING_TIME = 8_000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
         private const val STUB_TOKEN_GENERATION_TIME = 200L
     }

     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

+    private var _readyForSystemPrompt = false
+
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
     init {
+        llamaScope.launch {
             Log.i(TAG, "Initiated!")

             // Simulate library loading
+            delay(STUB_LIBRARY_LOADING_TIME)

             _state.value = State.LibraryLoaded
         }
+    }

     /**
      * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String) {
-        Log.i(TAG, "loadModel! state: ${_state.value}")
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }

             try {
+                _readyForSystemPrompt = false
                 _state.value = State.LoadingModel

                 // Simulate model loading
                 delay(STUB_MODEL_LOADING_TIME)

+                _readyForSystemPrompt = true
                 _state.value = State.ModelReady

             } catch (e: CancellationException) {
                 // If coroutine is cancelled, propagate cancellation
                 throw e
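Note: the stub now funnels every call through a single-parallelism dispatcher, mirroring the threading model the real engine already uses. A minimal runnable sketch of why that makes check-then-set state transitions safe; the class and names below are illustrative, not from this repo:

    import kotlinx.coroutines.*

    // Sketch: a parallelism-1 dispatcher serializes every withContext block,
    // so a check-then-set transition is atomic: no other block can interleave.
    class SerializedEngine {
        @OptIn(ExperimentalCoroutinesApi::class)
        private val dispatcher = Dispatchers.IO.limitedParallelism(1)
        private var busy = false

        suspend fun run(name: String) = withContext(dispatcher) {
            check(!busy) { "engine already busy" }  // never trips: calls are serialized
            busy = true
            println("start $name"); delay(100); println("end $name")
            busy = false
        }
    }

    fun main() = runBlocking {
        val e = SerializedEngine()
        // Three concurrent callers; output is strictly paired:
        // start call-0, end call-0, start call-1, end call-1, ...
        List(3) { i -> launch { e.run("call-$i") } }.joinAll()
    }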
@@ -67,7 +89,15 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Process the plain text system prompt
      */
-    override suspend fun setSystemPrompt(prompt: String) {
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
+            check(_readyForSystemPrompt) {
+                "System prompt must be set ** RIGHT AFTER ** model loaded!"
+            }

             try {
                 _state.value = State.ProcessingSystemPrompt
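Note: the new guards follow the standard Kotlin split: require(...) rejects bad caller input with IllegalArgumentException, while check(...) rejects bad internal state with IllegalStateException. A small self-contained sketch (names illustrative):

    fun demo(prompt: String, modelReady: Boolean) {
        // Bad input from the caller -> IllegalArgumentException
        require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
        // Bad internal state -> IllegalStateException
        check(modelReady) { "Cannot process system prompt before the model is ready!" }
    }

    fun main() {
        runCatching { demo("", modelReady = true) }
            .onFailure { println(it) }   // java.lang.IllegalArgumentException: ...
        runCatching { demo("hello", modelReady = false) }
            .onFailure { println(it) }   // java.lang.IllegalStateException: ...
    }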
@@ -86,15 +116,17 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
-        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.ModelReady) {
+            "Cannot load model at ${_state.value.javaClass.simpleName}"
+        }

+        try {
+            Log.i(TAG, "sendUserPrompt! \n$message")
             _state.value = State.ProcessingUserPrompt

-        // This would be replaced with actual token generation logic
-        return flow {
-            try {
-                // Simulate longer processing time (1.5 seconds)
+            // Simulate longer processing time
             delay(STUB_USER_PROMPT_PROCESSING_TIME)

             _state.value = State.Generating
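Note: converting sendUserPrompt to an expression-bodied flow { } moves the require/check guards inside a cold flow, so they now run each time the flow is collected rather than when the method returns. A runnable sketch of that behavior (names illustrative):

    import kotlinx.coroutines.flow.*
    import kotlinx.coroutines.runBlocking

    // Sketch: flow {} is cold, so guards inside the builder execute on collect,
    // not at the call site that builds the Flow.
    var ready = false

    fun tokens(): Flow<String> = flow {
        check(ready) { "not ready" }   // evaluated at collection time
        emit("hello")
    }

    fun main() = runBlocking {
        val f = tokens()               // no exception yet: nothing has run
        runCatching { f.collect { } }
            .onFailure { println(it) } // java.lang.IllegalStateException: not ready
        ready = true
        f.collect { println(it) }      // prints: hello
    }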
@@ -122,22 +154,25 @@ class StubInferenceEngine : InferenceEngine {
             }
             throw e
         }
-    }

     /**
      * Runs a benchmark with the specified parameters.
      */
-    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
-        Log.i(TAG, "bench! state: ${_state.value}")
-        _state.value = State.Benchmarking
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }

             try {
+                Log.i(TAG, "bench! state: ${_state.value}")
+                _state.value = State.Benchmarking

                 // Simulate benchmark running
                 delay(STUB_BENCHMARKING_TIME)

                 // Generate fake benchmark results
-        val modelDesc = "Kleidi Llama"
+                val modelDesc = APP_NAME
                 val model_size = "7"
                 val model_n_params = "7"
                 val backend = "CPU"
@@ -158,27 +193,38 @@ class StubInferenceEngine : InferenceEngine {

                 _state.value = State.ModelReady

-        return result.toString()
+                result.toString()
             } catch (e: CancellationException) {
                 // If coroutine is cancelled, propagate cancellation
+                Log.w(TAG, "Unexpected user cancellation while benchmarking!")
                 _state.value = State.ModelReady
                 throw e
             } catch (e: Exception) {
                 _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
-            return "Error: ${e.message}"
+                "Error: ${e.message}"
             }
         }

     /**
      * Unloads the currently loaded model.
      */
-    override suspend fun unloadModel() {
-        Log.i(TAG, "unloadModel! state: ${_state.value}")
+    override suspend fun unloadModel() =
+        withContext(llamaDispatcher) {
+            when(val state = _state.value) {
+                is State.ModelReady, is State.Error -> {
+                    Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}")
+                    _state.value = State.UnloadingModel

                     // Simulate model unloading time
-        delay(2000)
+                    delay(STUB_MODEL_UNLOADING_TIME)

                     _state.value = State.LibraryLoaded
                 }
+                else -> throw IllegalStateException(
+                    "Cannot load model at ${_state.value.javaClass.simpleName}"
+                )
+            }
+        }

     /**
      * Cleans up resources when the engine is no longer needed.
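Note: once bench() and its siblings become expression bodies, withContext returns the last expression of each branch, which is why the return keywords disappear in this diff. A sketch under that assumption (the benchmark strings are invented for illustration):

    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.runBlocking
    import kotlinx.coroutines.withContext

    // Sketch: with an expression body, withContext's lambda yields its last
    // expression, so no explicit `return` is needed (or allowed without a label).
    suspend fun bench(ok: Boolean): String = withContext(Dispatchers.Default) {
        try {
            if (!ok) error("benchmark failed")   // throws IllegalStateException
            "pp 512 | tg 128"                    // last expression = return value
        } catch (e: Exception) {
            "Error: ${e.message}"                // each branch ends in its value
        }
    }

    fun main() = runBlocking {
        println(bench(true))    // pp 512 | tg 128
        println(bench(false))   // Error: benchmark failed
    }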
InferenceEngine.kt

@@ -50,6 +50,7 @@ interface InferenceEngine {
         object LibraryLoaded : State()

         object LoadingModel : State()
+        object UnloadingModel : State()
         object ModelReady : State()

         object ProcessingSystemPrompt : State()
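Note: with UnloadingModel added, the State hierarchy visible across this commit reads roughly as follows; states not touched by the diff may differ in the real interface:

    // Sketch of the State hierarchy implied by this diff.
    sealed class State {
        object Uninitialized : State()
        object LibraryLoaded : State()
        object LoadingModel : State()
        object UnloadingModel : State()   // added by this commit
        object ModelReady : State()
        object ProcessingSystemPrompt : State()
        object ProcessingUserPrompt : State()
        object Generating : State()
        object Benchmarking : State()
        data class Error(val message: String) : State()
    }

    // Happy-path lifecycle implied by the stub's transitions:
    // Uninitialized -> LibraryLoaded -> LoadingModel -> ModelReady
    //   -> (ProcessingSystemPrompt | ProcessingUserPrompt -> Generating | Benchmarking) -> ModelReady
    //   -> UnloadingModel -> LibraryLoaded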
LLamaAndroid.kt

@@ -87,7 +87,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
-            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }
             File(pathToModel).let {
                 require(it.exists()) { "Model file not found: $pathToModel" }
                 require(it.isFile) { "Model file is not a file: $pathToModel" }
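Note: splitting the check message onto its own line costs nothing at runtime: the message lambda is only invoked when the check fails. A tiny sketch demonstrating that laziness:

    fun main() {
        var built = 0
        fun msg(): String { built++; return "failure detail" }

        check(true) { msg() }    // lambda never invoked on success
        println(built)           // 0

        runCatching { check(false) { msg() } }
        println(built)           // 1 -- message built only when the check fails
    }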
@@ -114,7 +116,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
         withContext(llamaDispatcher) {
             require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
             check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
-            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+            check(_state.value is State.ModelReady) {
+                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+            }

             Log.i(TAG, "Sending system prompt...")
             _readyForSystemPrompt = false
@@ -139,13 +143,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.ModelReady) {
-            "User prompt discarded due to: ${_state.value}"
+            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
         }

         try {
             Log.i(TAG, "Sending user prompt...")
             _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt

             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
                     Log.e(TAG, "Failed to process user prompt: $result")
@@ -194,16 +199,19 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
-            when(_state.value) {
+            when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
+                    _state.value = State.UnloadingModel

                     unload()

                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
                     Unit
                 }
-                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
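Note: a caller-side sketch of the unloadModel() contract after this commit: it succeeds from ModelReady or Error and throws IllegalStateException from any other state. This assumes the stub above can be constructed directly; the load call is left commented because the model path is environment-specific:

    import kotlinx.coroutines.runBlocking

    fun main() = runBlocking {
        val engine = StubInferenceEngine()      // state: Uninitialized
        runCatching { engine.unloadModel() }
            .onFailure { println(it) }          // IllegalStateException: no model loaded
        // engine.loadModel("/path/to/model.gguf")  // -> ModelReady
        // engine.unloadModel()                     // ModelReady -> UnloadingModel -> LibraryLoaded
    }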