llm: properly propagate error to UI upon failing to load selected model

This commit is contained in:
Han Yin 2025-07-21 13:10:41 -07:00
parent 3da54f497a
commit dd5b20d74d
6 changed files with 104 additions and 61 deletions

View File

@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
/** /**
* Load a model for benchmark * Load a model for benchmark
*/ */
suspend fun loadModelForBenchmark(): ModelLoadingMetrics suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
/** /**
* Load a model for conversation * Load a model for conversation
*/ */
suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
} }
interface BenchmarkService : InferenceService { interface BenchmarkService : InferenceService {
@ -132,7 +132,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model } override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
override suspend fun unloadModel() = inferenceEngine.unloadModel() override suspend fun unloadModel() = inferenceEngine.cleanUp()
/** /**
* Shut down inference engine * Shut down inference engine
@ -145,8 +145,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
* *
*/ */
override suspend fun loadModelForBenchmark(): ModelLoadingMetrics = override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
_currentModel.value?.let { model -> checkNotNull(_currentModel.value) { "Attempt to load model for bench while none selected!" }
return _currentModel.value?.let { model ->
try { try {
val modelLoadStartTs = System.currentTimeMillis() val modelLoadStartTs = System.currentTimeMillis()
inferenceEngine.loadModel(model.path) inferenceEngine.loadModel(model.path)
@ -154,12 +156,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs) ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error loading model", e) Log.e(TAG, "Error loading model", e)
throw e null
} }
} ?: throw IllegalStateException("No model selected!") }
}
override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics = override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
_currentModel.value?.let { model -> checkNotNull(_currentModel.value) { "Attempt to load model for chat while none selected!" }
return _currentModel.value?.let { model ->
try { try {
_systemPrompt.value = systemPrompt _systemPrompt.value = systemPrompt
@ -181,10 +186,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
} }
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error loading model", e) Log.e(TAG, "Error loading model", e)
throw e null
} }
} ?: throw IllegalStateException("No model selected!") }
}
/* /*
* *

View File

@ -206,7 +206,7 @@ class StubInferenceEngine : InferenceEngine {
/** /**
* Unloads the currently loaded model. * Unloads the currently loaded model.
*/ */
override suspend fun unloadModel() = override suspend fun cleanUp() =
withContext(llamaDispatcher) { withContext(llamaDispatcher) {
when(val state = _state.value) { when(val state = _state.value) {
is State.ModelReady, is State.Error -> { is State.ModelReady, is State.Error -> {

View File

@ -22,7 +22,10 @@ import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.selection.selectable import androidx.compose.foundation.selection.selectable
import androidx.compose.material.icons.Icons import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.Check import androidx.compose.material.icons.filled.Check
import androidx.compose.material.icons.filled.Error
import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.ButtonDefaults
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
@ -109,6 +112,7 @@ fun ModelLoadingScreen(
// Check if we're in a loading state // Check if we're in a loading state
val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
val errorMessage = (engineState as? State.Error)?.errorMessage
// Handle back navigation requests // Handle back navigation requests
BackHandler { BackHandler {
@ -315,13 +319,10 @@ fun ModelLoadingScreen(
customPromptText.takeIf { it.isNotBlank() } customPromptText.takeIf { it.isNotBlank() }
?.also { promptText -> ?.also { promptText ->
// Save custom prompt to database // Save custom prompt to database
viewModel.saveCustomPromptToRecents( viewModel.saveCustomPromptToRecents(promptText)
promptText
)
} }
} }
} else null } else null
viewModel.onConversationSelected(systemPrompt, onNavigateToConversation) viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
} }
@ -329,30 +330,51 @@ fun ModelLoadingScreen(
} }
} }
}, },
modifier = Modifier modifier = Modifier.fillMaxWidth().height(56.dp),
.fillMaxWidth() colors = if (errorMessage != null)
.height(56.dp), ButtonDefaults.buttonColors(
disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f),
disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f)
) else ButtonDefaults.buttonColors(),
enabled = selectedMode != null && !isLoading && enabled = selectedMode != null && !isLoading &&
(!useSystemPrompt || hasActiveSystemPrompt) (!useSystemPrompt || hasActiveSystemPrompt)
) { ) {
if (isLoading) { when {
CircularProgressIndicator( errorMessage != null -> {
modifier = Modifier Icon(
.height(24.dp) imageVector = Icons.Default.Error,
.width(24.dp) contentDescription = errorMessage,
) tint = MaterialTheme.colorScheme.error
Spacer(modifier = Modifier.width(8.dp)) )
Text( Spacer(modifier = Modifier.width(8.dp))
text = when (engineState) { Text(
is State.Initializing, State.Initialized -> "Initializing..." text = errorMessage,
is State.LoadingModel -> "Loading model..." color = MaterialTheme.colorScheme.onErrorContainer
is State.ProcessingSystemPrompt -> "Processing system prompt..." )
else -> "Processing..." }
},
style = MaterialTheme.typography.titleMedium isLoading -> {
) CircularProgressIndicator(modifier = Modifier.height(24.dp).width(24.dp))
} else { Spacer(modifier = Modifier.width(8.dp))
Text(text = "Start", style = MaterialTheme.typography.titleMedium) Text(
text = when (engineState) {
is State.Initializing, State.Initialized -> "Initializing..."
is State.LoadingModel -> "Loading model..."
is State.ProcessingSystemPrompt -> "Processing system prompt..."
else -> "Processing..."
},
style = MaterialTheme.typography.titleMedium
)
}
else -> {
Icon(
imageVector = Icons.Default.PlayArrow,
contentDescription = "Run model ${selectedModel?.name} with $selectedMode"
)
Spacer(modifier = Modifier.width(8.dp))
Text(text = "Start", style = MaterialTheme.typography.titleMedium)
}
} }
} }
} }

View File

@ -86,10 +86,12 @@ class ModelLoadingViewModel @Inject constructor(
*/ */
fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) = fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
viewModelScope.launch { viewModelScope.launch {
selectedModel.value?.let { selectedModel.value?.let { model ->
modelRepository.updateModelLastUsed(it.id) modelLoadingService.loadModelForBenchmark()?.let { metrics ->
modelRepository.updateModelLastUsed(model.id)
onNavigateToBenchmark(metrics)
}
} }
onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
} }
/** /**
@ -100,10 +102,12 @@ class ModelLoadingViewModel @Inject constructor(
systemPrompt: String? = null, systemPrompt: String? = null,
onNavigateToConversation: (ModelLoadingMetrics) -> Unit onNavigateToConversation: (ModelLoadingMetrics) -> Unit
) = viewModelScope.launch { ) = viewModelScope.launch {
selectedModel.value?.let { selectedModel.value?.let { model ->
modelRepository.updateModelLastUsed(it.id) modelLoadingService.loadModelForConversation(systemPrompt)?.let { metrics ->
modelRepository.updateModelLastUsed(model.id)
onNavigateToConversation(metrics)
}
} }
onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
} }
companion object { companion object {

View File

@ -36,7 +36,7 @@ interface InferenceEngine {
/** /**
* Unloads the currently loaded model. * Unloads the currently loaded model.
*/ */
suspend fun unloadModel() suspend fun cleanUp()
/** /**
* Cleans up resources when the engine is no longer needed. * Cleans up resources when the engine is no longer needed.

View File

@ -30,8 +30,8 @@ import java.io.File
* 2. Load a model with [loadModel] * 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt] * 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams * 4. Generate responses as token streams
* 5. Unload the model with [unloadModel] when switching models * 5. Perform [cleanUp] when done with a model
* 6. Call [destroy] when completely done * 6. Properly [destroy] when completely done
* *
* State transitions are managed automatically and validated at each operation. * State transitions are managed automatically and validated at each operation.
* *
@ -133,18 +133,23 @@ internal class InferenceEngineImpl private constructor(
require(it.isFile) { "Model file is not a file: $pathToModel" } require(it.isFile) { "Model file is not a file: $pathToModel" }
} }
Log.i(TAG, "Loading model... \n$pathToModel") try {
_readyForSystemPrompt = false Log.i(TAG, "Loading model... \n$pathToModel")
_state.value = InferenceEngine.State.LoadingModel _readyForSystemPrompt = false
load(pathToModel).let { result -> _state.value = InferenceEngine.State.LoadingModel
if (result != 0) throw IllegalStateException("Failed to Load model: $result") load(pathToModel).let {
if (it != 0) throw IllegalStateException("Failed to load the model!")
}
prepare().let {
if (it != 0) throw IllegalStateException("Failed to prepare resources!")
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
_state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
throw e
} }
prepare().let { result ->
if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
_state.value = InferenceEngine.State.ModelReady
} }
/** /**
@ -235,12 +240,12 @@ internal class InferenceEngineImpl private constructor(
} }
/** /**
* Unloads the model and frees resources * Unloads the model and frees resources, or reset error states
*/ */
override suspend fun unloadModel() = override suspend fun cleanUp() =
withContext(llamaDispatcher) { withContext(llamaDispatcher) {
when (val state = _state.value) { when (val state = _state.value) {
is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> { is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...") Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false _readyForSystemPrompt = false
_state.value = InferenceEngine.State.UnloadingModel _state.value = InferenceEngine.State.UnloadingModel
@ -252,6 +257,13 @@ internal class InferenceEngineImpl private constructor(
Unit Unit
} }
is InferenceEngine.State.Error -> {
Log.i(TAG, "Resetting error states...")
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "States reset!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}") else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
} }
} }