llm: properly propagate error to UI upon failing to load selected model
parent 3da54f497a
commit dd5b20d74d
@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?

     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }

 interface BenchmarkService : InferenceService {
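
Note: ModelLoadingMetrics itself is not declared in this diff; judging from the construction site further down (ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)) it presumably wraps the load duration. A minimal sketch of the new nullable contract, with that shape assumed:

// Sketch only — the real declaration is outside this diff; the field name is assumed.
data class ModelLoadingMetrics(val loadTimeMs: Long)

// Callers now branch on null instead of catching an exception:
suspend fun startBenchmark(service: ModelLoadingService) {
    val metrics = service.loadModelForBenchmark() ?: return  // null = load failed
    println("Model loaded in ${metrics.loadTimeMs} ms")
}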
@@ -132,7 +132,7 @@ internal class InferenceServiceImpl @Inject internal constructor(

     override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }

-    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+    override suspend fun unloadModel() = inferenceEngine.cleanUp()

     /**
      * Shut down inference engine
@@ -145,8 +145,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
      *
      */
-    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for bench while none selected!" }
+
+        return _currentModel.value?.let { model ->
             try {
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
@@ -154,12 +156,15 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
+        }
+    }

-    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
-        _currentModel.value?.let { model ->
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
+        checkNotNull(_currentModel.value) { "Attempt to load model for chat while none selected!" }
+
+        return _currentModel.value?.let { model ->
             try {
                 _systemPrompt.value = systemPrompt
@@ -181,10 +186,10 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 }
             } catch (e: Exception) {
                 Log.e(TAG, "Error loading model", e)
-                throw e
+                null
             }
-        } ?: throw IllegalStateException("No model selected!")
+        }
+    }

     /*
      *
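
Both service methods now follow the same contract: log and swallow the exception, return null, and rely on the engine (see the InferenceEngineImpl hunks below) having already published a State.Error for the UI. A sketch of the consuming side:

// Sketch; names match the diff.
val metrics = modelLoadingService.loadModelForConversation(systemPrompt)
if (metrics == null) {
    // Load failed: do not navigate. ModelLoadingScreen renders the
    // errorMessage it derives from the engine's State.Error instead.
}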
@@ -206,7 +206,7 @@ class StubInferenceEngine : InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
         withContext(llamaDispatcher) {
             when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
@@ -22,7 +22,10 @@ import androidx.compose.foundation.lazy.items
 import androidx.compose.foundation.selection.selectable
 import androidx.compose.material.icons.Icons
 import androidx.compose.material.icons.filled.Check
+import androidx.compose.material.icons.filled.Error
+import androidx.compose.material.icons.filled.PlayArrow
 import androidx.compose.material3.Button
+import androidx.compose.material3.ButtonDefaults
 import androidx.compose.material3.Card
 import androidx.compose.material3.CircularProgressIndicator
 import androidx.compose.material3.ExperimentalMaterial3Api
@@ -109,6 +112,7 @@ fun ModelLoadingScreen(

     // Check if we're in a loading state
     val isLoading = engineState !is State.Initialized && engineState !is State.ModelReady
+    val errorMessage = (engineState as? State.Error)?.errorMessage

     // Handle back navigation requests
     BackHandler {
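
The State hierarchy is not part of this diff; from the variants referenced here and below (Initializing, Initialized, LoadingModel, ProcessingSystemPrompt, ModelReady, UnloadingModel, Error) it is presumably a sealed hierarchy along these lines:

// Hypothetical reconstruction — the actual declaration may differ.
sealed interface State {
    object Initializing : State
    object Initialized : State
    object LoadingModel : State
    object ProcessingSystemPrompt : State
    object ModelReady : State
    object UnloadingModel : State
    data class Error(val errorMessage: String) : State
}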
@@ -315,13 +319,10 @@ fun ModelLoadingScreen(
                     customPromptText.takeIf { it.isNotBlank() }
                         ?.also { promptText ->
-                            // Save custom prompt to database
-                            viewModel.saveCustomPromptToRecents(
-                                promptText
-                            )
+                            viewModel.saveCustomPromptToRecents(promptText)
                         }
                 }
             } else null

             viewModel.onConversationSelected(systemPrompt, onNavigateToConversation)
         }
@@ -329,30 +330,51 @@ fun ModelLoadingScreen(
                 }
             }
         },
-        modifier = Modifier
-            .fillMaxWidth()
-            .height(56.dp),
+        modifier = Modifier.fillMaxWidth().height(56.dp),
+        colors = if (errorMessage != null)
+            ButtonDefaults.buttonColors(
+                disabledContainerColor = MaterialTheme.colorScheme.errorContainer.copy(alpha = 0.3f),
+                disabledContentColor = MaterialTheme.colorScheme.onErrorContainer.copy(alpha = 0.7f)
+            ) else ButtonDefaults.buttonColors(),
         enabled = selectedMode != null && !isLoading &&
                 (!useSystemPrompt || hasActiveSystemPrompt)
     ) {
-        if (isLoading) {
-            CircularProgressIndicator(
-                modifier = Modifier
-                    .height(24.dp)
-                    .width(24.dp)
-            )
-            Spacer(modifier = Modifier.width(8.dp))
-            Text(
-                text = when (engineState) {
-                    is State.Initializing, State.Initialized -> "Initializing..."
-                    is State.LoadingModel -> "Loading model..."
-                    is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    else -> "Processing..."
-                },
-                style = MaterialTheme.typography.titleMedium
-            )
-        } else {
-            Text(text = "Start", style = MaterialTheme.typography.titleMedium)
-        }
+        when {
+            errorMessage != null -> {
+                Icon(
+                    imageVector = Icons.Default.Error,
+                    contentDescription = errorMessage,
+                    tint = MaterialTheme.colorScheme.error
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = errorMessage,
+                    color = MaterialTheme.colorScheme.onErrorContainer
+                )
+            }
+
+            isLoading -> {
+                CircularProgressIndicator(modifier = Modifier.height(24.dp).width(24.dp))
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(
+                    text = when (engineState) {
+                        is State.Initializing, State.Initialized -> "Initializing..."
+                        is State.LoadingModel -> "Loading model..."
+                        is State.ProcessingSystemPrompt -> "Processing system prompt..."
+                        else -> "Processing..."
+                    },
+                    style = MaterialTheme.typography.titleMedium
+                )
+            }
+
+            else -> {
+                Icon(
+                    imageVector = Icons.Default.PlayArrow,
+                    contentDescription = "Run model ${selectedModel?.name} with $selectedMode"
+                )
+                Spacer(modifier = Modifier.width(8.dp))
+                Text(text = "Start", style = MaterialTheme.typography.titleMedium)
+            }
+        }
     }
 }
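
Why the error styling goes through the disabled color slots: State.Error is neither Initialized nor ModelReady, so isLoading stays true and the button remains disabled while an error is showing. Tracing the screen's predicates (State per the sketch above):

fun isLoading(engineState: State): Boolean =
    engineState !is State.Initialized && engineState !is State.ModelReady

fun startEnabled(
    engineState: State,
    selectedMode: Any?,
    useSystemPrompt: Boolean,
    hasActiveSystemPrompt: Boolean
): Boolean =
    selectedMode != null && !isLoading(engineState) &&
            (!useSystemPrompt || hasActiveSystemPrompt)

// isLoading(State.Error("...")) == true, so startEnabled(...) == false whenever
// an error is showing: the button stays disabled, and Material therefore applies
// disabledContainerColor / disabledContentColor — which is why the error colors
// are routed through the disabled slots rather than the enabled ones.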
@@ -86,10 +86,12 @@ class ModelLoadingViewModel @Inject constructor(
      */
     fun onBenchmarkSelected(onNavigateToBenchmark: (ModelLoadingMetrics) -> Unit) =
         viewModelScope.launch {
-            selectedModel.value?.let {
-                modelRepository.updateModelLastUsed(it.id)
+            selectedModel.value?.let { model ->
+                modelLoadingService.loadModelForBenchmark()?.let { metrics ->
+                    modelRepository.updateModelLastUsed(model.id)
+                    onNavigateToBenchmark(metrics)
+                }
             }
-            onNavigateToBenchmark(modelLoadingService.loadModelForBenchmark())
         }

     /**
@@ -100,10 +102,12 @@ class ModelLoadingViewModel @Inject constructor(
         systemPrompt: String? = null,
         onNavigateToConversation: (ModelLoadingMetrics) -> Unit
     ) = viewModelScope.launch {
-        selectedModel.value?.let {
-            modelRepository.updateModelLastUsed(it.id)
+        selectedModel.value?.let { model ->
+            modelLoadingService.loadModelForConversation(systemPrompt)?.let { metrics ->
+                modelRepository.updateModelLastUsed(model.id)
+                onNavigateToConversation(metrics)
+            }
         }
-        onNavigateToConversation(modelLoadingService.loadModelForConversation(systemPrompt))
     }

     companion object {
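
onBenchmarkSelected and onConversationSelected now share the same shape: navigate and record last-used only when the load returns non-null metrics. One behavioral consequence worth noting: updateModelLastUsed previously ran unconditionally and now runs only after a successful load. The common pattern, extracted as an illustrative helper (not a function in this codebase):

// Illustrative sketch only.
suspend fun <M> navigateIfLoaded(
    load: suspend () -> M?,          // e.g. loadModelForBenchmark()
    markUsed: suspend () -> Unit,    // e.g. updateModelLastUsed(model.id)
    navigate: (M) -> Unit            // e.g. onNavigateToBenchmark
) {
    load()?.let { metrics ->
        markUsed()                   // only successful loads count as "used"
        navigate(metrics)
    }
}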
@@ -36,7 +36,7 @@ interface InferenceEngine {
     /**
      * Unloads the currently loaded model.
      */
-    suspend fun unloadModel()
+    suspend fun cleanUp()

     /**
      * Cleans up resources when the engine is no longer needed.
@@ -30,8 +30,8 @@ import java.io.File
  * 2. Load a model with [loadModel]
  * 3. Send prompts with [sendUserPrompt]
  * 4. Generate responses as token streams
- * 5. Unload the model with [unloadModel] when switching models
- * 6. Call [destroy] when completely done
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
 *
 * State transitions are managed automatically and validated at each operation.
 *
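
The documented lifecycle in code form — a sketch that assumes sendUserPrompt returns a token Flow, which this diff does not show:

// Lifecycle sketch per the KDoc above; signatures outside this diff are assumed.
suspend fun runOnce(engine: InferenceEngine) {
    engine.loadModel("/path/to/model.gguf")   // 2. load a model
    engine.sendUserPrompt("Hello!")           // 3. send a prompt
        .collect { token -> print(token) }    // 4. consume the token stream
    engine.cleanUp()                          // 5. done with this model
    engine.destroy()                          // 6. completely done
}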
@@ -133,18 +133,23 @@ internal class InferenceEngineImpl private constructor(
             require(it.isFile) { "Model file is not a file: $pathToModel" }
         }

-        Log.i(TAG, "Loading model... \n$pathToModel")
-        _readyForSystemPrompt = false
-        _state.value = InferenceEngine.State.LoadingModel
-        load(pathToModel).let { result ->
-            if (result != 0) throw IllegalStateException("Failed to Load model: $result")
+        try {
+            Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
+            _state.value = InferenceEngine.State.LoadingModel
+            load(pathToModel).let {
+                if (it != 0) throw IllegalStateException("Failed to load the model!")
+            }
+            prepare().let {
+                if (it != 0) throw IllegalStateException("Failed to prepare resources!")
+            }
+            Log.i(TAG, "Model loaded!")
+            _readyForSystemPrompt = true
+            _state.value = InferenceEngine.State.ModelReady
+        } catch (e: Exception) {
+            _state.value = InferenceEngine.State.Error(e.message ?: "Unknown error")
+            throw e
         }
-        prepare().let { result ->
-            if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
-        }
-        Log.i(TAG, "Model loaded!")
-        _readyForSystemPrompt = true
-        _state.value = InferenceEngine.State.ModelReady
     }

     /**
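
The substantive change here: every failure inside loadModel now lands the engine in State.Error(message) before the exception propagates. That is what lets the service layer above swallow the rethrown exception into a null result without losing the message:

// Sketch of the layering this enables (caller side).
try {
    inferenceEngine.loadModel(model.path)
} catch (e: Exception) {
    // The engine has already published State.Error(e.message ?: "Unknown error"),
    // so observers of the state flow can render it; this layer only needs to
    // map the failure to a null ModelLoadingMetrics.
}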
@@ -235,12 +240,12 @@ internal class InferenceEngineImpl private constructor(
     }

     /**
-     * Unloads the model and frees resources
+     * Unloads the model and frees resources, or resets error states
      */
-    override suspend fun unloadModel() =
+    override suspend fun cleanUp() =
         withContext(llamaDispatcher) {
             when (val state = _state.value) {
-                is InferenceEngine.State.ModelReady, is InferenceEngine.State.Error -> {
+                is InferenceEngine.State.ModelReady -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
                     _state.value = InferenceEngine.State.UnloadingModel
@@ -252,6 +257,13 @@
                     Unit
                 }

+                is InferenceEngine.State.Error -> {
+                    Log.i(TAG, "Resetting error states...")
+                    _state.value = InferenceEngine.State.Initialized
+                    Log.i(TAG, "States reset!")
+                    Unit
+                }
+
                 else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
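
With Error split into its own branch, cleanUp() doubles as error recovery: it resets Error back to Initialized without running the native unload path. A hypothetical retry flow built on top (not part of the commit; assumes loadModel suspends and throws on failure, as above):

// Hypothetical helper.
suspend fun retryLoad(engine: InferenceEngine, path: String) {
    try {
        engine.loadModel(path)
    } catch (_: Exception) {
        engine.cleanUp()         // State.Error -> State.Initialized
        engine.loadModel(path)   // single retry, assuming a transient failure
    }
}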