llm: properly propagate error to UI upon failing to load selected model

Han Yin 2025-07-21 13:10:41 -07:00
parent 3da54f497a
commit dd5b20d74d
6 changed files with 104 additions and 61 deletions

View File

@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }
 interface BenchmarkService : InferenceService {
@@ -132,7 +132,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
-    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+    override suspend fun unloadModel() = inferenceEngine.cleanUp()
     /**
      * Shut down inference engine
@@ -145,8 +145,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
      *
      */
-    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for bench while none selected!" }
+        return _currentModel.value?.let { model ->
             try {
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
@@ -154,12 +156,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
+        }
+    }
-    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for chat while none selected!" }
+        return _currentModel.value?.let { model ->
             try {
                 _systemPrompt.value = systemPrompt
@@ -181,10 +186,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 }
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
+        }
+    }
     /*
      *
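Note on the contract change above: a missing model selection is treated as a programmer error (checkNotNull throws), while a failed load is an expected runtime condition that now surfaces as a null return instead of a rethrown exception. A minimal call-site sketch, using only the names shown in this diff (startBenchmarkFlow itself is hypothetical):

    suspend fun startBenchmarkFlow(service: ModelLoadingService) {
        // Throws IllegalStateException if no model is selected (programmer error).
        // Returns null if the selected model failed to load (expected failure).
        val metrics: ModelLoadingMetrics? = service.loadModelForBenchmark()
        if (metrics == null) return  // engine is left in its Error state; the UI renders it
        // Otherwise it is safe to navigate onward with the metrics.
    }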

View File

@@ -206,7 +206,7 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
         withContext(llamaDispatcher) {
             when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {

View File

@@ -22,7 +22,10 @@ import androidx.compose.foundation.lazy.items
 import androidx.compose.foundation.selection.selectable
 import androidx.compose.material.icons.Icons
 import androidx.compose.material.icons.filled.Check
+import androidx.compose.material.icons.filled.Error
+import androidx.compose.material.icons.filled.PlayArrow
 import androidx.compose.material3.Button
+import androidx.compose.material3.ButtonDefaults
 import androidx.compose.material3.Card
 import androidx.compose.material3.CircularProgressIndicator
 import androidx.compose.material3.ExperimentalMaterial3Api
@@ -109,6 +112,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
+    val errorMessage = (engineState as? State.Error)?.errorMessage
     // Handle back navigation requests
     BackHandler {
@@ -315,13 +319,10 @@ fun ModelLoadingScreen(
                     customPromptText.takeIf { it.isNotBlank() }
                         ?.also { promptText ->
                             // Save custom prompt to database
-                            viewModel.saveCustomPromptToRecents(
-                                promptText
-                            )
+                            viewModel.saveCustomPromptToRecents(promptText)
                         }
                 }
             } else null
             viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
         }
@@ -329,30 +330,51 @@ fun ModelLoadingScreen(
                 }
             }
         },
-        modifier = Modifier
-            .fillMaxWidth()
-            .height(56.dp),
+        modifier = Modifier.fillMaxWidth().height(56.dp),
+        colors = if (errorMessage != null)
+            ButtonDefaults.buttonColors(
+                disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f),
+                disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f)
+            ) else ButtonDefaults.buttonColors(),
         enabled = selectedMode != null && !isLoading &&
                 (!useSystemPrompt || hasActiveSystemPrompt)
     ) {
-        if (isLoading) {
-            CircularProgressIndicator(
-                modifier = Modifier
-                    .height(24.dp)
-                    .width(24.dp)
-            )
-            Spacer(modifier = Modifier.width(8.dp))
-            Text(
-                text = when (engineState) {
-                    is State.Initializing, State.Initialized -> "Initializing..."
-                    is State.LoadingModel -> "Loading model..."
-                    is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    else -> "Processing..."
-                },
-                style = MaterialTheme.typography.titleMedium
-            )
-        } else {
-            Text(text = "Start", style = MaterialTheme.typography.titleMedium)
-        }
+        when {
+            errorMessage != null -> {
+                Icon(
+                    imageVector = Icons.Default.Error,
+                    contentDescription = errorMessage,
+                    tint = MaterialTheme.colorScheme.error
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = errorMessage,
+                    color = MaterialTheme.colorScheme.onErrorContainer
+                )
+            }
+            isLoading -> {
+                CircularProgressIndicator(modifier = Modifier.height(24.dp).width(24.dp))
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = when (engineState) {
+                        is State.Initializing, State.Initialized -> "Initializing..."
+                        is State.LoadingModel -> "Loading model..."
+                        is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                        else -> "Processing..."
+                    },
+                    style = MaterialTheme.typography.titleMedium
+                )
+            }
+            else -> {
+                Icon(
+                    imageVector = Icons.Default.PlayArrow,
+                    contentDescription = "Run model ${selectedModel?.name} with $selectedMode"
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(text = "Start", style = MaterialTheme.typography.titleMedium)
+            }
+        }
     }
 }
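For orientation, the State branches matched above imply roughly the following shape; this is a reconstruction from the names used in this commit, not the actual declaration:

    sealed interface State {
        data object Initializing : State
        data object Initialized : State
        data object LoadingModel : State
        data object ProcessingSystemPrompt : State
        data object UnloadingModel : State
        data object ModelReady : State
        data class Error(val errorMessage: String) : State
    }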

View File

@@ -86,10 +86,12 @@ class ModelLoadingViewModel @Inject constructor(
      */
     fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
         viewModelScope.launch {
-            selectedModel.value?.let {
-                modelRepository.updateModelLastUsed(it.id)
+            selectedModel.value?.let { model ->
+                modelLoadingService.loadModelForBenchmark()?.let { metrics ->
+                    modelRepository.updateModelLastUsed(model.id)
+                    onNavigateToBenchmark(metrics)
+                }
             }
-            onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
         }
     /**
@@ -100,10 +102,12 @@ class ModelLoadingViewModel @Inject constructor(
         systemPrompt: String? = null,
         onNavigateToConversation: (ModelLoadingMetrics) -> Unit
     ) = viewModelScope.launch {
-        selectedModel.value?.let {
-            modelRepository.updateModelLastUsed(it.id)
+        selectedModel.value?.let { model ->
+            modelLoadingService.loadModelForConversation(systemPrompt)?.let { metrics ->
+                modelRepository.updateModelLastUsed(model.id)
+                onNavigateToConversation(metrics)
+            }
         }
-        onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
     }
     companion object {
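Both handlers now use the same gating pattern: the last-used update and the navigation callback run only when loading returns non-null metrics. The pattern in isolation, with hypothetical names:

    suspend fun navigateIfLoaded(
        loadOrNull: suspend () -> ModelLoadingMetrics?,
        navigate: (ModelLoadingMetrics) -> Unit
    ) {
        // On null, nothing fires; the screen stays put and renders the error.
        loadOrNull()?.let(navigate)
    }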

View File

@@ -36,7 +36,7 @@ interface InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    suspend fun unloadModel()
+    suspend fun cleanUp()
     /**
      * Cleans up resources when the engine is no longer needed.
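The rename widens the contract from "unload a loaded model" to "return the engine to Initialized", which after this commit includes recovering from an Error state. A sketch of the members this commit touches (other members and exact signatures are assumptions):

    interface InferenceEngine {
        suspend fun loadModel(pathToModel: String)  // sets the Error state and rethrows on failure
        suspend fun cleanUp()                       // was unloadModel(); also resets the Error state
        fun destroy()                               // signature assumed from the kdoc
    }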

View File

@@ -30,8 +30,8 @@ import java.io.File
  * 2. Load a model with [loadModel]
  * 3. Send prompts with [sendUserPrompt]
  * 4. Generate responses as token streams
- * 5. Unload the model with [unloadModel] when switching models
- * 6. Call [destroy] when completely done
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
  *
  * State transitions are managed automatically and validated at each operation.
  *
@@ -133,18 +133,23 @@ internal class InferenceEngineImpl private constructor(
             require(it.isFile) { "Model file is not a file: $pathToModel" }
         }
-        Log.i(TAG, "Loading model... \n$pathToModel")
-        _readyForSystemPrompt = false
-        _state.value = InferenceEngine.State.LoadingModel
-        load(pathToModel).let { result ->
-            if (result != 0) throw IllegalStateException("Failed to Load model: $result")
-        }
-        prepare().let { result ->
-            if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
-        }
-        Log.i(TAG, "Model loaded!")
-        _readyForSystemPrompt = true
-        _state.value = InferenceEngine.State.ModelReady
+        try {
+            Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
+            _state.value = InferenceEngine.State.LoadingModel
+            load(pathToModel).let {
+                if (it != 0) throw IllegalStateException("Failed to load the model!")
+            }
+            prepare().let {
+                if (it != 0) throw IllegalStateException("Failed to prepare resources!")
+            }
+            Log.i(TAG, "Model loaded!")
+            _readyForSystemPrompt = true
+            _state.value = InferenceEngine.State.ModelReady
+        } catch (e: Exception) {
+            _state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
+            throw e
+        }
     }
     /**
@@ -235,12 +240,12 @@ internal class InferenceEngineImpl private constructor(
     }
     /**
-     * Unloads the model and frees resources
+     * Unloads the model and frees resources, or reset error states
      */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
         withContext(llamaDispatcher) {
             when (val state = _state.value) {
-                is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> {
+                is InferenceEngine.State.ModelReady -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
                     _state.value = InferenceEngine.State.UnloadingModel
@@ -252,6 +257,13 @@ internal class InferenceEngineImpl private constructor(
                     Unit
                 }
+                is InferenceEngine.State.Error -> {
+                    Log.i(TAG, "Resetting error states...")
+                    _state.value = InferenceEngine.State.Initialized
+                    Log.i(TAG, "States reset!")
+                    Unit
+                }
                 else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
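Summarizing the cleanUp() transitions above, plus a sketch of the retry path the UI can now take (retryAfterFailure is hypothetical; the engine calls are the ones from this diff):

    // ModelReady -> UnloadingModel -> Initialized   (unload and free resources)
    // Error      -> Initialized                     (reset so the user can retry)
    // any other  -> IllegalStateException
    suspend fun retryAfterFailure(engine: InferenceEngine, path: String) {
        engine.cleanUp()        // clears the sticky Error state
        engine.loadModel(path)  // on failure: sets the Error state again and rethrows
    }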