llm: properly propagate error to UI upon failing to load selected model
This commit is contained in:
parent
3da54f497a
commit
dd5b20d74d
|
|
@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
|
||||||
/**
|
/**
|
||||||
* Load a model for benchmark
|
* Load a model for benchmark
|
||||||
*/
|
*/
|
||||||
suspend fun loadModelForBenchmark(): ModelLoadingMetrics
|
suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Load a model for conversation
|
* Load a model for conversation
|
||||||
*/
|
*/
|
||||||
suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
|
suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
|
||||||
}
|
}
|
||||||
|
|
||||||
interface BenchmarkService : InferenceService {
|
interface BenchmarkService : InferenceService {
|
||||||
|
|
@ -132,7 +132,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
|
|
||||||
override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
|
override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
|
||||||
|
|
||||||
override suspend fun unloadModel() = inferenceEngine.unloadModel()
|
override suspend fun unloadModel() = inferenceEngine.cleanUp()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shut down inference engine
|
* Shut down inference engine
|
||||||
|
|
@ -145,8 +145,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
|
||||||
override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
|
override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
|
||||||
_currentModel.value?.let { model ->
|
checkNotNull(_currentModel.value) { "Attempt to load model for bench while none selected!" }
|
||||||
|
|
||||||
|
return _currentModel.value?.let { model ->
|
||||||
try {
|
try {
|
||||||
val modelLoadStartTs = System.currentTimeMillis()
|
val modelLoadStartTs = System.currentTimeMillis()
|
||||||
inferenceEngine.loadModel(model.path)
|
inferenceEngine.loadModel(model.path)
|
||||||
|
|
@ -154,12 +156,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
|
ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.e(TAG, "Error loading model", e)
|
Log.e(TAG, "Error loading model", e)
|
||||||
throw e
|
null
|
||||||
}
|
}
|
||||||
} ?: throw IllegalStateException("No model selected!")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
|
override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
|
||||||
_currentModel.value?.let { model ->
|
checkNotNull(_currentModel.value) { "Attempt to load model for chat while none selected!" }
|
||||||
|
|
||||||
|
return _currentModel.value?.let { model ->
|
||||||
try {
|
try {
|
||||||
_systemPrompt.value = systemPrompt
|
_systemPrompt.value = systemPrompt
|
||||||
|
|
||||||
|
|
@ -181,10 +186,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
Log.e(TAG, "Error loading model", e)
|
Log.e(TAG, "Error loading model", e)
|
||||||
throw e
|
null
|
||||||
}
|
}
|
||||||
} ?: throw IllegalStateException("No model selected!")
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
*
|
*
|
||||||
|
|
|
||||||
|
|
@ -206,7 +206,7 @@ class StubInferenceEngine : InferenceEngine {
|
||||||
/**
|
/**
|
||||||
* Unloads the currently loaded model.
|
* Unloads the currently loaded model.
|
||||||
*/
|
*/
|
||||||
override suspend fun unloadModel() =
|
override suspend fun cleanUp() =
|
||||||
withContext(llamaDispatcher) {
|
withContext(llamaDispatcher) {
|
||||||
when(val state = _state.value) {
|
when(val state = _state.value) {
|
||||||
is State.ModelReady, is State.Error -> {
|
is State.ModelReady, is State.Error -> {
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,10 @@ import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.foundation.selection.selectable
|
import androidx.compose.foundation.selection.selectable
|
||||||
import androidx.compose.material.icons.Icons
|
import androidx.compose.material.icons.Icons
|
||||||
import androidx.compose.material.icons.filled.Check
|
import androidx.compose.material.icons.filled.Check
|
||||||
|
import androidx.compose.material.icons.filled.Error
|
||||||
|
import androidx.compose.material.icons.filled.PlayArrow
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.ButtonDefaults
|
||||||
import androidx.compose.material3.Card
|
import androidx.compose.material3.Card
|
||||||
import androidx.compose.material3.CircularProgressIndicator
|
import androidx.compose.material3.CircularProgressIndicator
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
|
@ -109,6 +112,7 @@ fun ModelLoadingScreen(
|
||||||
|
|
||||||
// Check if we're in a loading state
|
// Check if we're in a loading state
|
||||||
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
|
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
|
||||||
|
val errorMessage = (engineState as? State.Error)?.errorMessage
|
||||||
|
|
||||||
// Handle back navigation requests
|
// Handle back navigation requests
|
||||||
BackHandler {
|
BackHandler {
|
||||||
|
|
@ -315,13 +319,10 @@ fun ModelLoadingScreen(
|
||||||
customPromptText.takeIf { it.isNotBlank() }
|
customPromptText.takeIf { it.isNotBlank() }
|
||||||
?.also { promptText ->
|
?.also { promptText ->
|
||||||
// Save custom prompt to database
|
// Save custom prompt to database
|
||||||
viewModel.saveCustomPromptToRecents(
|
viewModel.saveCustomPromptToRecents(promptText)
|
||||||
promptText
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else null
|
} else null
|
||||||
|
|
||||||
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
|
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -329,30 +330,51 @@ fun ModelLoadingScreen(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
modifier = Modifier
|
modifier = Modifier.fillMaxWidth().height(56.dp),
|
||||||
.fillMaxWidth()
|
colors = if (errorMessage != null)
|
||||||
.height(56.dp),
|
ButtonDefaults.buttonColors(
|
||||||
|
disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f),
|
||||||
|
disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f)
|
||||||
|
) else ButtonDefaults.buttonColors(),
|
||||||
enabled = selectedMode != null && !isLoading &&
|
enabled = selectedMode != null && !isLoading &&
|
||||||
(!useSystemPrompt || hasActiveSystemPrompt)
|
(!useSystemPrompt || hasActiveSystemPrompt)
|
||||||
) {
|
) {
|
||||||
if (isLoading) {
|
when {
|
||||||
CircularProgressIndicator(
|
errorMessage != null -> {
|
||||||
modifier = Modifier
|
Icon(
|
||||||
.height(24.dp)
|
imageVector = Icons.Default.Error,
|
||||||
.width(24.dp)
|
contentDescription = errorMessage,
|
||||||
)
|
tint = MaterialTheme.colorScheme.error
|
||||||
Spacer(modifier = Modifier.width(8.dp))
|
)
|
||||||
Text(
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
text = when (engineState) {
|
Text(
|
||||||
is State.Initializing, State.Initialized -> "Initializing..."
|
text = errorMessage,
|
||||||
is State.LoadingModel -> "Loading model..."
|
color = MaterialTheme.colorScheme.onErrorContainer
|
||||||
is State.ProcessingSystemPrompt -> "Processing system prompt..."
|
)
|
||||||
else -> "Processing..."
|
}
|
||||||
},
|
|
||||||
style = MaterialTheme.typography.titleMedium
|
isLoading -> {
|
||||||
)
|
CircularProgressIndicator(modifier = Modifier.height(24.dp).width(24.dp))
|
||||||
} else {
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
Text(text = "Start", style = MaterialTheme.typography.titleMedium)
|
Text(
|
||||||
|
text = when (engineState) {
|
||||||
|
is State.Initializing, State.Initialized -> "Initializing..."
|
||||||
|
is State.LoadingModel -> "Loading model..."
|
||||||
|
is State.ProcessingSystemPrompt -> "Processing system prompt..."
|
||||||
|
else -> "Processing..."
|
||||||
|
},
|
||||||
|
style = MaterialTheme.typography.titleMedium
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
else -> {
|
||||||
|
Icon(
|
||||||
|
imageVector = Icons.Default.PlayArrow,
|
||||||
|
contentDescription = "Run model ${selectedModel?.name} with $selectedMode"
|
||||||
|
)
|
||||||
|
Spacer(modifier = Modifier.width(8.dp))
|
||||||
|
Text(text = "Start", style = MaterialTheme.typography.titleMedium)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -86,10 +86,12 @@ class ModelLoadingViewModel @Inject constructor(
|
||||||
*/
|
*/
|
||||||
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
|
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
selectedModel.value?.let {
|
selectedModel.value?.let { model ->
|
||||||
modelRepository.updateModelLastUsed(it.id)
|
modelLoadingService.loadModelForBenchmark()?.let { metrics ->
|
||||||
|
modelRepository.updateModelLastUsed(model.id)
|
||||||
|
onNavigateToBenchmark(metrics)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -100,10 +102,12 @@ class ModelLoadingViewModel @Inject constructor(
|
||||||
systemPrompt: String? = null,
|
systemPrompt: String? = null,
|
||||||
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
|
onNavigateToConversation: (ModelLoadingMetrics) -> Unit
|
||||||
) = viewModelScope.launch {
|
) = viewModelScope.launch {
|
||||||
selectedModel.value?.let {
|
selectedModel.value?.let { model ->
|
||||||
modelRepository.updateModelLastUsed(it.id)
|
modelLoadingService.loadModelForConversation(systemPrompt)?.let { metrics ->
|
||||||
|
modelRepository.updateModelLastUsed(model.id)
|
||||||
|
onNavigateToConversation(metrics)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ interface InferenceEngine {
|
||||||
/**
|
/**
|
||||||
* Unloads the currently loaded model.
|
* Unloads the currently loaded model, or resets the engine state after an error.
|
||||||
*/
|
*/
|
||||||
suspend fun unloadModel()
|
suspend fun cleanUp()
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Cleans up resources when the engine is no longer needed.
|
* Cleans up resources when the engine is no longer needed.
|
||||||
|
|
|
||||||
|
|
@ -30,8 +30,8 @@ import java.io.File
|
||||||
* 2. Load a model with [loadModel]
|
* 2. Load a model with [loadModel]
|
||||||
* 3. Send prompts with [sendUserPrompt]
|
* 3. Send prompts with [sendUserPrompt]
|
||||||
* 4. Generate responses as token streams
|
* 4. Generate responses as token streams
|
||||||
* 5. Unload the model with [unloadModel] when switching models
|
* 5. Perform [cleanUp] when done with a model
|
||||||
* 6. Call [destroy] when completely done
|
* 6. Properly [destroy] when completely done
|
||||||
*
|
*
|
||||||
* State transitions are managed automatically and validated at each operation.
|
* State transitions are managed automatically and validated at each operation.
|
||||||
*
|
*
|
||||||
|
|
@ -133,18 +133,23 @@ internal class InferenceEngineImpl private constructor(
|
||||||
require(it.isFile) { "Model file is not a file: $pathToModel" }
|
require(it.isFile) { "Model file is not a file: $pathToModel" }
|
||||||
}
|
}
|
||||||
|
|
||||||
Log.i(TAG, "Loading model... \n$pathToModel")
|
try {
|
||||||
_readyForSystemPrompt = false
|
Log.i(TAG, "Loading model... \n$pathToModel")
|
||||||
_state.value = InferenceEngine.State.LoadingModel
|
_readyForSystemPrompt = false
|
||||||
load(pathToModel).let { result ->
|
_state.value = InferenceEngine.State.LoadingModel
|
||||||
if (result != 0) throw IllegalStateException("Failed to Load model: $result")
|
load(pathToModel).let {
|
||||||
|
if (it != 0) throw IllegalStateException("Failed to load the model!")
|
||||||
|
}
|
||||||
|
prepare().let {
|
||||||
|
if (it != 0) throw IllegalStateException("Failed to prepare resources!")
|
||||||
|
}
|
||||||
|
Log.i(TAG, "Model loaded!")
|
||||||
|
_readyForSystemPrompt = true
|
||||||
|
_state.value = InferenceEngine.State.ModelReady
|
||||||
|
} catch (e: Exception) {
|
||||||
|
_state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
|
||||||
|
throw e
|
||||||
}
|
}
|
||||||
prepare().let { result ->
|
|
||||||
if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
|
|
||||||
}
|
|
||||||
Log.i(TAG, "Model loaded!")
|
|
||||||
_readyForSystemPrompt = true
|
|
||||||
_state.value = InferenceEngine.State.ModelReady
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -235,12 +240,12 @@ internal class InferenceEngineImpl private constructor(
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Unloads the model and frees resources
|
* Unloads the model and frees resources, or resets the error state
|
||||||
*/
|
*/
|
||||||
override suspend fun unloadModel() =
|
override suspend fun cleanUp() =
|
||||||
withContext(llamaDispatcher) {
|
withContext(llamaDispatcher) {
|
||||||
when (val state = _state.value) {
|
when (val state = _state.value) {
|
||||||
is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> {
|
is InferenceEngine.State.ModelReady -> {
|
||||||
Log.i(TAG, "Unloading model and free resources...")
|
Log.i(TAG, "Unloading model and free resources...")
|
||||||
_readyForSystemPrompt = false
|
_readyForSystemPrompt = false
|
||||||
_state.value = InferenceEngine.State.UnloadingModel
|
_state.value = InferenceEngine.State.UnloadingModel
|
||||||
|
|
@ -252,6 +257,13 @@ internal class InferenceEngineImpl private constructor(
|
||||||
Unit
|
Unit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is InferenceEngine.State.Error -> {
|
||||||
|
Log.i(TAG, "Resetting error states...")
|
||||||
|
_state.value = InferenceEngine.State.Initialized
|
||||||
|
Log.i(TAG, "States reset!")
|
||||||
|
Unit
|
||||||
|
}
|
||||||
|
|
||||||
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
|
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue