diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt index bcd36b7914..936f43b891 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubInferenceEngine.kt @@ -80,7 +80,7 @@ class StubInferenceEngine : InferenceEngine { // If coroutine is cancelled, propagate cancellation throw e } catch (e: Exception) { - _state.value = State.Error(e.message ?: "Unknown error during model loading") + _state.value = State.Error(e) } } @@ -107,7 +107,7 @@ class StubInferenceEngine : InferenceEngine { // If coroutine is cancelled, propagate cancellation throw e } catch (e: Exception) { - _state.value = State.Error(e.message ?: "Unknown error during model loading") + _state.value = State.Error(e) } } @@ -142,13 +142,13 @@ class StubInferenceEngine : InferenceEngine { _state.value = State.ModelReady throw e } catch (e: Exception) { - _state.value = State.Error(e.message ?: "Unknown error during generation") + _state.value = State.Error(e) throw e } }.catch { e -> // If it's not a cancellation, update state to error if (e !is CancellationException) { - _state.value = State.Error(e.message ?: "Unknown error during generation") + _state.value = State.Error(Exception(e)) } throw e } @@ -198,7 +198,7 @@ class StubInferenceEngine : InferenceEngine { _state.value = State.ModelReady throw e } catch (e: Exception) { - _state.value = State.Error(e.message ?: "Unknown error during benchmarking") + _state.value = State.Error(e) "Error: ${e.message}" } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt index 5381c56c7e..cb70888408 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/ui/screens/ModelLoadingScreen.kt @@ -1,6 +1,7 @@ package com.example.llama.ui.screens import android.llama.cpp.InferenceEngine.State +import android.llama.cpp.UnsupportedArchitectureException import androidx.activity.compose.BackHandler import androidx.compose.animation.AnimatedVisibility import androidx.compose.animation.expandVertically @@ -112,7 +113,7 @@ fun ModelLoadingScreen( // Check if we're in a loading state val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady - val errorMessage = (engineState as? State.Error)?.errorMessage + val exception = (engineState as? State.Error)?.exception // Handle back navigation requests BackHandler { @@ -331,7 +332,7 @@ fun ModelLoadingScreen( } }, modifier = Modifier.fillMaxWidth().height(56.dp), - colors = if (errorMessage != null) + colors = if (exception != null) ButtonDefaults.buttonColors( disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f), disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f) @@ -340,15 +341,21 @@ fun ModelLoadingScreen( (!useSystemPrompt || hasActiveSystemPrompt) ) { when { - errorMessage != null -> { + exception != null -> { + val message = if (exception is UnsupportedArchitectureException) { + "Unsupported architecture: ${selectedModel?.metadata?.architecture?.architecture}" + } else { + exception.message ?: "Unknown error" + } + Icon( imageVector = Icons.Default.Error, - contentDescription = errorMessage, + contentDescription = message, tint = MaterialTheme.colorScheme.error ) Spacer(modifier = Modifier.width(8.dp)) Text( - text = errorMessage, + text = message, color = MaterialTheme.colorScheme.onErrorContainer ) } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt index 36cbce5afa..62ae87af12 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt @@ -15,6 +15,8 @@ interface InferenceEngine { /** * Load a model from the given path. + * + * @throws UnsupportedArchitectureException if model architecture not supported */ suspend fun loadModel(pathToModel: String) @@ -61,7 +63,7 @@ interface InferenceEngine { object Generating : State() - data class Error(val errorMessage: String = "") : State() + data class Error(val exception: Exception) : State() } companion object { @@ -81,3 +83,5 @@ val State.isModelLoaded: Boolean this !is State.Initialized && this !is State.LoadingModel && this !is State.UnloadingModel + +class UnsupportedArchitectureException : Exception() diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt index 40153af873..008ce9ebdd 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt @@ -2,6 +2,7 @@ package android.llama.cpp.internal import android.llama.cpp.InferenceEngine import android.llama.cpp.LLamaTier +import android.llama.cpp.UnsupportedArchitectureException import android.util.Log import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope @@ -143,7 +144,7 @@ internal class InferenceEngineImpl private constructor( _state.value = InferenceEngine.State.LoadingModel load(pathToModel).let { // TODO-han.yin: find a better way to pass other error codes - if (it != 0) throw IOException("Unsupported architecture") + if (it != 0) throw UnsupportedArchitectureException() } prepare().let { if (it != 0) throw IOException("Failed to prepare resources") @@ -152,9 +153,8 @@ internal class InferenceEngineImpl private constructor( _readyForSystemPrompt = true _state.value = InferenceEngine.State.ModelReady } catch (e: Exception) { - val msg = e.message ?: "Unknown error" - Log.e(TAG, msg + "\n" + pathToModel, e) - _state.value = InferenceEngine.State.Error(msg) + Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e) + _state.value = InferenceEngine.State.Error(e) throw e } } @@ -177,9 +177,10 @@ internal class InferenceEngineImpl private constructor( _state.value = InferenceEngine.State.ProcessingSystemPrompt processSystemPrompt(prompt).let { result -> if (result != 0) { - val errorMessage = "Failed to process system prompt: $result" - _state.value = InferenceEngine.State.Error(errorMessage) - throw IllegalStateException(errorMessage) + RuntimeException("Failed to process system prompt: $result").also { + _state.value = InferenceEngine.State.Error(it) + throw it + } } } Log.i(TAG, "System prompt processed! Awaiting user prompt...") @@ -225,7 +226,7 @@ internal class InferenceEngineImpl private constructor( throw e } catch (e: Exception) { Log.e(TAG, "Error during generation!", e) - _state.value = InferenceEngine.State.Error(e.message ?: "Unknown error") + _state.value = InferenceEngine.State.Error(e) throw e } }.flowOn(llamaDispatcher)