From cb508be78239946c4f0088d52df25badb72416a4 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Fri, 18 Apr 2025 15:01:33 -0700 Subject: [PATCH] UI: migrate ModelLoadingScreen onto ModelLoadingViewModel; update & refine ModelLoadingScreen --- .../revamp/ui/screens/ModelLoadingScreen.kt | 280 +++++++++--------- .../revamp/viewmodel/ModelLoadingViewModel.kt | 3 +- 2 files changed, 149 insertions(+), 134 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt index ca55df5e02..cf5c97687d 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt @@ -1,6 +1,7 @@ package com.example.llama.revamp.ui.screens import android.llama.cpp.InferenceEngine.State +import androidx.activity.compose.BackHandler import androidx.compose.animation.AnimatedVisibility import androidx.compose.animation.expandVertically import androidx.compose.animation.fadeIn @@ -48,13 +49,19 @@ import androidx.compose.ui.Modifier import androidx.compose.ui.semantics.Role import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp -import androidx.hilt.navigation.compose.hiltViewModel import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.ui.components.ModelCard +import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import kotlinx.coroutines.Job import kotlinx.coroutines.launch + +enum class Mode { + BENCHMARK, + CONVERSATION +} + enum class SystemPromptTab { PRESETS, CUSTOM, RECENTS } @@ -62,17 +69,18 @@ enum class SystemPromptTab { @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) @Composable fun ModelLoadingScreen( - engineState: State, + onNavigateBack: () -> Unit, onBenchmarkSelected: (prepareJob: Job) -> Unit, onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit, - onBackPressed: () -> Unit, - modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(), + viewModel: ModelLoadingViewModel, ) { val coroutineScope = rememberCoroutineScope() - val selectedModel by modelLoadingViewModel.selectedModel.collectAsState() - val presetPrompts by modelLoadingViewModel.presetPrompts.collectAsState() - val recentPrompts by modelLoadingViewModel.recentPrompts.collectAsState() + val engineState by viewModel.engineState.collectAsState() + val selectedModel by viewModel.selectedModel.collectAsState() + val presetPrompts by viewModel.presetPrompts.collectAsState() + val recentPrompts by viewModel.recentPrompts.collectAsState() + val unloadDialogState by viewModel.unloadModelState.collectAsState() var selectedMode by remember { mutableStateOf(null) } var useSystemPrompt by remember { mutableStateOf(false) } @@ -99,24 +107,30 @@ fun ModelLoadingScreen( // Check if we're in a loading state val isLoading = engineState !is State.Uninitialized && - engineState !is State.LibraryLoaded && + engineState !is State.Initialized && engineState !is State.ModelReady // Mode selection callbacks val handleBenchmarkSelected = { val prepareJob = coroutineScope.launch { - modelLoadingViewModel.prepareForBenchmark() + viewModel.prepareForBenchmark() } onBenchmarkSelected(prepareJob) } + // TODO-han.yin: refactor this into ViewModel too val handleConversationSelected = { systemPrompt: String? -> val prepareJob = coroutineScope.launch { - modelLoadingViewModel.prepareForConversation(systemPrompt) + viewModel.prepareForConversation(systemPrompt) } onConversationSelected(systemPrompt, prepareJob) } + // Handle back navigation requests + BackHandler { + viewModel.onBackPressed(onNavigateBack) + } + Column( modifier = Modifier .fillMaxSize() @@ -246,124 +260,33 @@ fun ModelLoadingScreen( .padding(top = 4.dp, bottom = 8.dp) ) - // Tab selector using SegmentedButton - SingleChoiceSegmentedButtonRow( - modifier = Modifier.fillMaxWidth() - ) { - SegmentedButton( - selected = selectedTab == SystemPromptTab.PRESETS, - onClick = { selectedTab = SystemPromptTab.PRESETS }, - shape = SegmentedButtonDefaults.itemShape(index = 0, count = 3), - icon = { - if (selectedTab == SystemPromptTab.PRESETS) { - Icon( - imageVector = Icons.Default.Check, - contentDescription = null - ) - } - }, - label = { Text("Presets") } - ) - - SegmentedButton( - selected = selectedTab == SystemPromptTab.CUSTOM, - onClick = { selectedTab = SystemPromptTab.CUSTOM }, - shape = SegmentedButtonDefaults.itemShape(index = 1, count = 3), - icon = { - if (selectedTab == SystemPromptTab.CUSTOM) { - Icon( - imageVector = Icons.Default.Check, - contentDescription = null - ) - } - }, - label = { Text("Custom") } - ) - - SegmentedButton( - selected = selectedTab == SystemPromptTab.RECENTS, - onClick = { selectedTab = SystemPromptTab.RECENTS }, - shape = SegmentedButtonDefaults.itemShape(index = 2, count = 3), - icon = { - if (selectedTab == SystemPromptTab.RECENTS) { - Icon( - imageVector = Icons.Default.Check, - contentDescription = null - ) - } - }, - label = { Text("Recents") } - ) - } + SystemPromptTabSelector( + selectedTab = selectedTab, + onTabSelected = { selectedTab = it } + ) Spacer(modifier = Modifier.height(8.dp)) - // Content based on selected tab - when (selectedTab) { - SystemPromptTab.PRESETS -> { - if (presetPrompts.isEmpty()) { - Text( - text = "No preset prompts available.", - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.padding(16.dp) - ) - } else { - PromptList( - prompts = presetPrompts, - selectedPromptId = selectedPrompt?.id, - expandedPromptId = expandedPromptId, - onPromptSelected = { - selectedPrompt = it - expandedPromptId = it.id - }, - onExpandPrompt = { expandedPromptId = it } - ) + SystemPromptTabContent( + selectedTab = selectedTab, + presetPrompts = presetPrompts, + recentPrompts = recentPrompts, + customPromptText = customPromptText, + onCustomPromptChange = { + customPromptText = it + // Deselect any preset prompt if typing custom + if (it.isNotBlank()) { + selectedPrompt = null } - } - - SystemPromptTab.CUSTOM -> { - // Custom prompt editor (fill remaining space) - OutlinedTextField( - value = customPromptText, - onValueChange = { - customPromptText = it - // Deselect any preset prompt if typing custom - if (it.isNotBlank()) { - selectedPrompt = null - } - }, - modifier = Modifier - .fillMaxWidth() - .fillMaxSize(), // Fill available space - label = { Text("Enter system prompt") }, - placeholder = { Text("You are a helpful assistant...") }, - minLines = 5 - ) - } - - SystemPromptTab.RECENTS -> { - if (recentPrompts.isEmpty()) { - Text( - text = "No recent prompts found.", - style = MaterialTheme.typography.bodyMedium, - color = MaterialTheme.colorScheme.onSurfaceVariant, - modifier = Modifier.padding(16.dp) - ) - } else { - PromptList( - prompts = recentPrompts, - selectedPromptId = selectedPrompt?.id, - expandedPromptId = expandedPromptId, - onPromptSelected = { - selectedPrompt = it - expandedPromptId = it.id - }, - onExpandPrompt = { expandedPromptId = it } - ) - } - } - } + }, + selectedPromptId = selectedPrompt?.id, + expandedPromptId = expandedPromptId, + onPromptSelected = { + selectedPrompt = it + expandedPromptId = it.id + }, + onExpandPrompt = { expandedPromptId = it } + ) } } } @@ -388,7 +311,7 @@ fun ModelLoadingScreen( SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> selectedPrompt?.let { prompt -> // Save the prompt to recent prompts database - modelLoadingViewModel.savePromptToRecents(prompt) + viewModel.savePromptToRecents(prompt) prompt.content } @@ -396,7 +319,7 @@ fun ModelLoadingScreen( customPromptText.takeIf { it.isNotBlank() } ?.also { promptText -> // Save custom prompt to database - modelLoadingViewModel.saveCustomPromptToRecents( + viewModel.saveCustomPromptToRecents( promptText ) } @@ -437,11 +360,109 @@ fun ModelLoadingScreen( } } } + + // Unload confirmation dialog + ModelUnloadDialogHandler( + unloadModelState = unloadDialogState, + onUnloadConfirmed = { viewModel.onUnloadConfirmed(onNavigateBack) }, + onUnloadDismissed = { viewModel.onUnloadDismissed() }, + onNavigateBack = onNavigateBack, + ) } +@Composable +private fun SystemPromptTabSelector( + selectedTab: SystemPromptTab, + onTabSelected: (SystemPromptTab) -> Unit +) { + SingleChoiceSegmentedButtonRow( + modifier = Modifier.fillMaxWidth() + ) { + SystemPromptTab.entries.forEachIndexed { index, tab -> + SegmentedButton( + selected = selectedTab == tab, + onClick = { onTabSelected(tab) }, + shape = SegmentedButtonDefaults.itemShape( + index = index, + count = SystemPromptTab.entries.size + ), + icon = { + if (selectedTab == tab) { + Icon( + imageVector = Icons.Default.Check, + contentDescription = null + ) + } + }, + label = { + Text( + when (tab) { + SystemPromptTab.PRESETS -> "Presets" + SystemPromptTab.CUSTOM -> "Custom" + SystemPromptTab.RECENTS -> "Recents" + } + ) + } + ) + } + } +} + +@Composable +private fun SystemPromptTabContent( + selectedTab: SystemPromptTab, + presetPrompts: List, + recentPrompts: List, + customPromptText: String, + onCustomPromptChange: (String) -> Unit, + selectedPromptId: String?, + expandedPromptId: String?, + onPromptSelected: (SystemPrompt) -> Unit, + onExpandPrompt: (String) -> Unit +) { + when (selectedTab) { + SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> { + val prompts = if (selectedTab == SystemPromptTab.PRESETS) presetPrompts else recentPrompts + + if (prompts.isEmpty()) { + Text( + text = + if (selectedTab == SystemPromptTab.PRESETS) "No preset prompts available." + else "No recent prompts found.", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(16.dp) + ) + } else { + PromptList( + prompts = prompts, + selectedPromptId = selectedPromptId, + expandedPromptId = expandedPromptId, + onPromptSelected = onPromptSelected, + onExpandPrompt = onExpandPrompt + ) + } + } + + SystemPromptTab.CUSTOM -> { + OutlinedTextField( + value = customPromptText, + onValueChange = onCustomPromptChange, + modifier = Modifier + .fillMaxWidth() + .fillMaxSize(), + label = { Text("Enter system prompt") }, + placeholder = { Text("You are a helpful assistant...") }, + minLines = 5 + ) + } + } +} + + @OptIn(ExperimentalFoundationApi::class) @Composable -fun PromptList( +private fun PromptList( prompts: List, selectedPromptId: String?, expandedPromptId: String?, @@ -516,8 +537,3 @@ fun PromptList( } } } - -enum class Mode { - BENCHMARK, - CONVERSATION -} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt index 46efcc6942..d12cfcb36a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt @@ -1,6 +1,5 @@ package com.example.llama.revamp.viewmodel -import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.repository.SystemPromptRepository @@ -16,7 +15,7 @@ import javax.inject.Inject class ModelLoadingViewModel @Inject constructor( private val modelLoadingService: ModelLoadingService, private val repository: SystemPromptRepository -) : ViewModel() { +) : ModelUnloadingViewModel(modelLoadingService) { /** * Currently selected model to be loaded