From dd5b20d74d212c766e1068bddfe4f5df7410dcfb Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Mon, 21 Jul 2025 13:10:41 -0700
Subject: [PATCH] llm: properly propagate error to UI upon failing to load selected model

---
 .../example/llama/engine/InferenceServices.kt | 29 ++++----
 .../llama/engine/StubInferenceEngine.kt       |  2 +-
 .../llama/ui/screens/ModelLoadingScreen.kt    | 72 ++++++++++++-------
 .../llama/viewmodel/ModelLoadingViewModel.kt  | 16 +++--
 .../java/android/llama/cpp/InferenceEngine.kt |  2 +-
 .../llama/cpp/internal/InferenceEngineImpl.kt | 44 +++++++-----
 6 files changed, 104 insertions(+), 61 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt
index 201bfc6adf..c9ffec68c8 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/InferenceServices.kt
@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
 
     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }
 
 interface BenchmarkService : InferenceService {
@@ -132,7 +132,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
 
-    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+    override suspend fun unloadModel() = inferenceEngine.cleanUp()
 
     /**
      * Shut down inference engine
@@ -145,8 +145,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
      *
      */
-    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for benchmark while none selected!" }
+
+        return _currentModel.value?.let { model ->
             try {
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
@@ -154,12 +156,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 val modelLoadEndTs = System.currentTimeMillis()
                 ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
+        }
+    }
 
-    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for conversation while none selected!" }
+
+        return _currentModel.value?.let { model ->
             try {
                 _systemPrompt.value = systemPrompt
@@ -181,10 +186,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 }
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
-
+        }
+    }
 
     /*
      *
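Note on the new contract: a null return from loadModelForBenchmark() or loadModelForConversation() now means "the load itself failed"; the engine records the reason in State.Error, which the UI renders directly. Calling either method with no model selected remains a programming error and still throws via checkNotNull. A minimal caller sketch, assuming nothing beyond the interface above (the TAG constant and function name are illustrative, not from the patch):

    import android.util.Log
    import com.example.llama.engine.ModelLoadingService

    private const val TAG = "ModelLoadingExample" // illustrative name, not from the patch

    // Hedged sketch: consuming the nullable return introduced by this patch.
    suspend fun loadForBenchmarkOrLog(service: ModelLoadingService): Boolean {
        val metrics = service.loadModelForBenchmark() ?: run {
            // The engine is now in State.Error; the UI reads the message from there.
            Log.w(TAG, "Model failed to load; staying on the loading screen")
            return false
        }
        Log.i(TAG, "Model loaded: $metrics")
        return true
    }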
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt
index bd8fc1c365..bcd36b7914 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt
@@ -206,7 +206,7 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Unloads the currently loaded model.
     */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
        withContext(llamaDispatcher) {
            when(val state = _state.value) {
                is State.ModelReady, is State.Error -> {
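The stub keeps accepting both ModelReady and Error in cleanUp(): with the rename, cleanUp() doubles as the error-recovery path (see InferenceEngineImpl below). A sketch of that recovery flow, assuming the interface exposes its state as a StateFlow<InferenceEngine.State> named state, as the stub's _state backing field suggests:

    import android.llama.cpp.InferenceEngine

    // Hedged sketch: after a failed load, cleanUp() is the route back to a
    // usable engine. The `state` property is an assumption about the interface.
    suspend fun recoverIfNeeded(engine: InferenceEngine) {
        if (engine.state.value is InferenceEngine.State.Error) {
            engine.cleanUp() // resets Error back to Initialized
        }
    }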
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt
index 0c9f5ec0bd..5381c56c7e 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt
@@ -22,7 +22,10 @@ import androidx.compose.foundation.lazy.items
 import androidx.compose.foundation.selection.selectable
 import androidx.compose.material.icons.Icons
 import androidx.compose.material.icons.filled.Check
+import androidx.compose.material.icons.filled.Error
+import androidx.compose.material.icons.filled.PlayArrow
 import androidx.compose.material3.Button
+import androidx.compose.material3.ButtonDefaults
 import androidx.compose.material3.Card
 import androidx.compose.material3.CircularProgressIndicator
 import androidx.compose.material3.ExperimentalMaterial3Api
@@ -109,6 +112,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
+    val errorMessage = (engineState as? State.Error)?.errorMessage
 
     // Handle back navigation requests
     BackHandler {
@@ -315,13 +319,10 @@ fun ModelLoadingScreen(
                             customPromptText.takeIf { it.isNotBlank() }
                                 ?.also { promptText ->
                                     // Save custom prompt to database
-                                    viewModel.saveCustomPromptToRecents(
-                                        promptText
-                                    )
+                                    viewModel.saveCustomPromptToRecents(promptText)
                                 }
                         }
                     } else null
-
                     viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
 
                 }
@@ -329,30 +330,51 @@ fun ModelLoadingScreen(
                 }
             }
         },
-        modifier = Modifier
-            .fillMaxWidth()
-            .height(56.dp),
+        modifier = Modifier.fillMaxWidth().height(56.dp),
+        colors = if (errorMessage != null)
+            ButtonDefaults.buttonColors(
+                disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f),
+                disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f)
+            ) else ButtonDefaults.buttonColors(),
         enabled = selectedMode != null && !isLoading && (!useSystemPrompt || hasActiveSystemPrompt)
     ) {
-        if (isLoading) {
-            CircularProgressIndicator(
-                modifier = Modifier
-                    .height(24.dp)
-                    .width(24.dp)
-            )
-            Spacer(modifier = Modifier.width(8.dp))
-            Text(
-                text = when (engineState) {
-                    is State.Initializing, State.Initialized -> "Initializing..."
-                    is State.LoadingModel -> "Loading model..."
-                    is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    else -> "Processing..."
-                },
-                style = MaterialTheme.typography.titleMedium
-            )
-        } else {
-            Text(text = "Start", style = MaterialTheme.typography.titleMedium)
+        when {
+            errorMessage != null -> {
+                Icon(
+                    imageVector = Icons.Default.Error,
+                    contentDescription = errorMessage,
+                    tint = MaterialTheme.colorScheme.error
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = errorMessage,
+                    color = MaterialTheme.colorScheme.onErrorContainer
+                )
+            }
+
+            isLoading -> {
+                CircularProgressIndicator(modifier = Modifier.height(24.dp).width(24.dp))
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = when (engineState) {
+                        is State.Initializing, State.Initialized -> "Initializing..."
+                        is State.LoadingModel -> "Loading model..."
+                        is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                        else -> "Processing..."
+                    },
+                    style = MaterialTheme.typography.titleMedium
+                )
+            }
+
+            else -> {
+                Icon(
+                    imageVector = Icons.Default.PlayArrow,
+                    contentDescription = "Run model ${selectedModel?.name} with $selectedMode"
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(text = "Start", style = MaterialTheme.typography.titleMedium)
+            }
         }
     }
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelLoadingViewModel.kt
index 671ae2f69a..689fee13c0 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelLoadingViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/ModelLoadingViewModel.kt
@@ -86,10 +86,12 @@ class ModelLoadingViewModel @Inject constructor(
      */
     fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
         viewModelScope.launch {
-            selectedModel.value?.let {
-                modelRepository.updateModelLastUsed(it.id)
+            selectedModel.value?.let { model ->
+                modelLoadingService.loadModelForBenchmark()?.let { metrics ->
+                    modelRepository.updateModelLastUsed(model.id)
+                    onNavigateToBenchmark(metrics)
+                }
             }
-            onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
         }
 
     /**
@@ -100,10 +102,12 @@ class ModelLoadingViewModel @Inject constructor(
         systemPrompt: String? = null,
         onNavigateToConversation: (ModelLoadingMetrics) -> Unit
     ) = viewModelScope.launch {
-        selectedModel.value?.let {
-            modelRepository.updateModelLastUsed(it.id)
+        selectedModel.value?.let { model ->
+            modelLoadingService.loadModelForConversation(systemPrompt)?.let { metrics ->
+                modelRepository.updateModelLastUsed(model.id)
+                onNavigateToConversation(metrics)
+            }
         }
-        onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
     }
 
     companion object {
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
index b39a8bcfc4..36cbce5afa 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
@@ -36,7 +36,7 @@ interface InferenceEngine {
     /**
      * Unloads the currently loaded model.
     */
-    suspend fun unloadModel()
+    suspend fun cleanUp()
 
     /**
      * Cleans up resources when the engine is no longer needed.
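Both view-model callbacks now share the same guard: update the last-used timestamp and navigate only when loading succeeded; on null, stay on the screen so the button's error branch renders. A hypothetical helper (not part of this patch) capturing that pattern:

    import kotlinx.coroutines.CoroutineScope
    import kotlinx.coroutines.Job
    import kotlinx.coroutines.launch

    // Hypothetical extension: run side effects and navigation only when the
    // loader returns non-null metrics; a null result is a no-op here because
    // the failure is already reflected in the engine state the UI observes.
    fun <T : Any> CoroutineScope.launchIfLoaded(
        load: suspend () -> T?,
        onLoaded: suspend (T) -> Unit,
    ): Job = launch {
        load()?.let { onLoaded(it) }
    }

With it, each callback would reduce to a single viewModelScope.launchIfLoaded(...) call wrapping the service call and the navigation lambda.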
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt
index 8a242f11a3..9f322857d8 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt
@@ -30,8 +30,8 @@ import java.io.File
 * 2. Load a model with [loadModel]
 * 3. Send prompts with [sendUserPrompt]
 * 4. Generate responses as token streams
- * 5. Unload the model with [unloadModel] when switching models
- * 6. Call [destroy] when completely done
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
 *
 * State transitions are managed automatically and validated at each operation.
 *
@@ -133,18 +133,23 @@ internal class InferenceEngineImpl private constructor(
            require(it.isFile) { "Model file is not a file: $pathToModel" }
        }
 
-        Log.i(TAG, "Loading model... \n$pathToModel")
-        _readyForSystemPrompt = false
-        _state.value = InferenceEngine.State.LoadingModel
-        load(pathToModel).let { result ->
-            if (result != 0) throw IllegalStateException("Failed to Load model: $result")
+        try {
+            Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
+            _state.value = InferenceEngine.State.LoadingModel
+            load(pathToModel).let {
+                if (it != 0) throw IllegalStateException("Failed to load model: $it")
+            }
+            prepare().let {
+                if (it != 0) throw IllegalStateException("Failed to prepare resources: $it")
+            }
+            Log.i(TAG, "Model loaded!")
+            _readyForSystemPrompt = true
+            _state.value = InferenceEngine.State.ModelReady
+        } catch (e: Exception) {
+            _state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
+            throw e
        }
-        prepare().let { result ->
-            if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
-        }
-        Log.i(TAG, "Model loaded!")
-        _readyForSystemPrompt = true
-        _state.value = InferenceEngine.State.ModelReady
    }
 
    /**
@@ -235,12 +240,12 @@ internal class InferenceEngineImpl private constructor(
    }
 
    /**
-     * Unloads the model and frees resources
+     * Unloads the model and frees resources, or resets the error state
     */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
        withContext(llamaDispatcher) {
            when (val state = _state.value) {
-                is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> {
+                is InferenceEngine.State.ModelReady -> {
                    Log.i(TAG, "Unloading model and free resources...")
                    _readyForSystemPrompt = false
                    _state.value = InferenceEngine.State.UnloadingModel
@@ -252,6 +257,13 @@ internal class InferenceEngineImpl private constructor(
                    }
                    Unit
                }
 
+                is InferenceEngine.State.Error -> {
+                    Log.i(TAG, "Resetting error state...")
+                    _state.value = InferenceEngine.State.Initialized
+                    Log.i(TAG, "State reset!")
+                    Unit
+                }
+
                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
            }
        }
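Taken together, the failure path is now: loadModel() either reaches ModelReady, or records State.Error(message) and rethrows; cleanUp() returns an errored engine to Initialized so another load can be attempted. A hedged end-to-end sketch from a caller's point of view (the app itself routes the failure through the service's null return and the UI's error button instead):

    import android.llama.cpp.InferenceEngine

    // Hedged sketch of the state transitions introduced by this patch.
    suspend fun loadOrRecover(engine: InferenceEngine, path: String): Boolean =
        try {
            engine.loadModel(path) // on failure: engine enters State.Error, then throws
            true                   // engine is now in State.ModelReady
        } catch (e: Exception) {
            // e.message is what the engine recorded in State.Error.
            engine.cleanUp()       // Error -> Initialized; a new model can be selected
            false
        }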