UI: migrate ModelLoadingScreen onto ModelLoadingViewModel; update & refine ModelLoadingScreen

This commit is contained in:
Han Yin 2025-04-18 15:01:33 -07:00
parent f61c512223
commit cb508be782
2 changed files with 149 additions and 134 deletions

View File

@ -1,6 +1,7 @@
package com.example.llama.revamp.ui.screens package com.example.llama.revamp.ui.screens
import android.llama.cpp.InferenceEngine.State import android.llama.cpp.InferenceEngine.State
import androidx.activity.compose.BackHandler
import androidx.compose.animation.AnimatedVisibility import androidx.compose.animation.AnimatedVisibility
import androidx.compose.animation.expandVertically import androidx.compose.animation.expandVertically
import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeIn
@ -48,13 +49,19 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.semantics.Role import androidx.compose.ui.semantics.Role
import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.text.style.TextOverflow
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.ui.components.ModelCard import com.example.llama.revamp.ui.components.ModelCard
import com.example.llama.revamp.ui.components.ModelUnloadDialogHandler
import com.example.llama.revamp.viewmodel.ModelLoadingViewModel import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
enum class Mode {
BENCHMARK,
CONVERSATION
}
enum class SystemPromptTab { enum class SystemPromptTab {
PRESETS, CUSTOM, RECENTS PRESETS, CUSTOM, RECENTS
} }
@ -62,17 +69,18 @@ enum class SystemPromptTab {
@OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class) @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
@Composable @Composable
fun ModelLoadingScreen( fun ModelLoadingScreen(
engineState: State, onNavigateBack: () -> Unit,
onBenchmarkSelected: (prepareJob: Job) -> Unit, onBenchmarkSelected: (prepareJob: Job) -> Unit,
onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit, onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
onBackPressed: () -> Unit, viewModel: ModelLoadingViewModel,
modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
) { ) {
val coroutineScope = rememberCoroutineScope() val coroutineScope = rememberCoroutineScope()
val selectedModel by modelLoadingViewModel.selectedModel.collectAsState() val engineState by viewModel.engineState.collectAsState()
val presetPrompts by modelLoadingViewModel.presetPrompts.collectAsState() val selectedModel by viewModel.selectedModel.collectAsState()
val recentPrompts by modelLoadingViewModel.recentPrompts.collectAsState() val presetPrompts by viewModel.presetPrompts.collectAsState()
val recentPrompts by viewModel.recentPrompts.collectAsState()
val unloadDialogState by viewModel.unloadModelState.collectAsState()
var selectedMode by remember { mutableStateOf<Mode?>(null) } var selectedMode by remember { mutableStateOf<Mode?>(null) }
var useSystemPrompt by remember { mutableStateOf(false) } var useSystemPrompt by remember { mutableStateOf(false) }
@ -99,24 +107,30 @@ fun ModelLoadingScreen(
// Check if we're in a loading state // Check if we're in a loading state
val isLoading = engineState !is State.Uninitialized && val isLoading = engineState !is State.Uninitialized &&
engineState !is State.LibraryLoaded && engineState !is State.Initialized &&
engineState !is State.ModelReady engineState !is State.ModelReady
// Mode selection callbacks // Mode selection callbacks
val handleBenchmarkSelected = { val handleBenchmarkSelected = {
val prepareJob = coroutineScope.launch { val prepareJob = coroutineScope.launch {
modelLoadingViewModel.prepareForBenchmark() viewModel.prepareForBenchmark()
} }
onBenchmarkSelected(prepareJob) onBenchmarkSelected(prepareJob)
} }
// TODO-han.yin: refactor this into ViewModel too
val handleConversationSelected = { systemPrompt: String? -> val handleConversationSelected = { systemPrompt: String? ->
val prepareJob = coroutineScope.launch { val prepareJob = coroutineScope.launch {
modelLoadingViewModel.prepareForConversation(systemPrompt) viewModel.prepareForConversation(systemPrompt)
} }
onConversationSelected(systemPrompt, prepareJob) onConversationSelected(systemPrompt, prepareJob)
} }
// Handle back navigation requests
BackHandler {
viewModel.onBackPressed(onNavigateBack)
}
Column( Column(
modifier = Modifier modifier = Modifier
.fillMaxSize() .fillMaxSize()
@ -246,124 +260,33 @@ fun ModelLoadingScreen(
.padding(top = 4.dp, bottom = 8.dp) .padding(top = 4.dp, bottom = 8.dp)
) )
// Tab selector using SegmentedButton SystemPromptTabSelector(
SingleChoiceSegmentedButtonRow( selectedTab = selectedTab,
modifier = Modifier.fillMaxWidth() onTabSelected = { selectedTab = it }
) { )
SegmentedButton(
selected = selectedTab == SystemPromptTab.PRESETS,
onClick = { selectedTab = SystemPromptTab.PRESETS },
shape = SegmentedButtonDefaults.itemShape(index = 0, count = 3),
icon = {
if (selectedTab == SystemPromptTab.PRESETS) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null
)
}
},
label = { Text("Presets") }
)
SegmentedButton(
selected = selectedTab == SystemPromptTab.CUSTOM,
onClick = { selectedTab = SystemPromptTab.CUSTOM },
shape = SegmentedButtonDefaults.itemShape(index = 1, count = 3),
icon = {
if (selectedTab == SystemPromptTab.CUSTOM) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null
)
}
},
label = { Text("Custom") }
)
SegmentedButton(
selected = selectedTab == SystemPromptTab.RECENTS,
onClick = { selectedTab = SystemPromptTab.RECENTS },
shape = SegmentedButtonDefaults.itemShape(index = 2, count = 3),
icon = {
if (selectedTab == SystemPromptTab.RECENTS) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null
)
}
},
label = { Text("Recents") }
)
}
Spacer(modifier = Modifier.height(8.dp)) Spacer(modifier = Modifier.height(8.dp))
// Content based on selected tab SystemPromptTabContent(
when (selectedTab) { selectedTab = selectedTab,
SystemPromptTab.PRESETS -> { presetPrompts = presetPrompts,
if (presetPrompts.isEmpty()) { recentPrompts = recentPrompts,
Text( customPromptText = customPromptText,
text = "No preset prompts available.", onCustomPromptChange = {
style = MaterialTheme.typography.bodyMedium, customPromptText = it
color = MaterialTheme.colorScheme.onSurfaceVariant, // Deselect any preset prompt if typing custom
modifier = Modifier.padding(16.dp) if (it.isNotBlank()) {
) selectedPrompt = null
} else {
PromptList(
prompts = presetPrompts,
selectedPromptId = selectedPrompt?.id,
expandedPromptId = expandedPromptId,
onPromptSelected = {
selectedPrompt = it
expandedPromptId = it.id
},
onExpandPrompt = { expandedPromptId = it }
)
} }
} },
selectedPromptId = selectedPrompt?.id,
SystemPromptTab.CUSTOM -> { expandedPromptId = expandedPromptId,
// Custom prompt editor (fill remaining space) onPromptSelected = {
OutlinedTextField( selectedPrompt = it
value = customPromptText, expandedPromptId = it.id
onValueChange = { },
customPromptText = it onExpandPrompt = { expandedPromptId = it }
// Deselect any preset prompt if typing custom )
if (it.isNotBlank()) {
selectedPrompt = null
}
},
modifier = Modifier
.fillMaxWidth()
.fillMaxSize(), // Fill available space
label = { Text("Enter system prompt") },
placeholder = { Text("You are a helpful assistant...") },
minLines = 5
)
}
SystemPromptTab.RECENTS -> {
if (recentPrompts.isEmpty()) {
Text(
text = "No recent prompts found.",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(16.dp)
)
} else {
PromptList(
prompts = recentPrompts,
selectedPromptId = selectedPrompt?.id,
expandedPromptId = expandedPromptId,
onPromptSelected = {
selectedPrompt = it
expandedPromptId = it.id
},
onExpandPrompt = { expandedPromptId = it }
)
}
}
}
} }
} }
} }
@ -388,7 +311,7 @@ fun ModelLoadingScreen(
SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
selectedPrompt?.let { prompt -> selectedPrompt?.let { prompt ->
// Save the prompt to recent prompts database // Save the prompt to recent prompts database
modelLoadingViewModel.savePromptToRecents(prompt) viewModel.savePromptToRecents(prompt)
prompt.content prompt.content
} }
@ -396,7 +319,7 @@ fun ModelLoadingScreen(
customPromptText.takeIf { it.isNotBlank() } customPromptText.takeIf { it.isNotBlank() }
?.also { promptText -> ?.also { promptText ->
// Save custom prompt to database // Save custom prompt to database
modelLoadingViewModel.saveCustomPromptToRecents( viewModel.saveCustomPromptToRecents(
promptText promptText
) )
} }
@ -437,11 +360,109 @@ fun ModelLoadingScreen(
} }
} }
} }
// Unload confirmation dialog
ModelUnloadDialogHandler(
unloadModelState = unloadDialogState,
onUnloadConfirmed = { viewModel.onUnloadConfirmed(onNavigateBack) },
onUnloadDismissed = { viewModel.onUnloadDismissed() },
onNavigateBack = onNavigateBack,
)
} }
@Composable
private fun SystemPromptTabSelector(
selectedTab: SystemPromptTab,
onTabSelected: (SystemPromptTab) -> Unit
) {
SingleChoiceSegmentedButtonRow(
modifier = Modifier.fillMaxWidth()
) {
SystemPromptTab.entries.forEachIndexed { index, tab ->
SegmentedButton(
selected = selectedTab == tab,
onClick = { onTabSelected(tab) },
shape = SegmentedButtonDefaults.itemShape(
index = index,
count = SystemPromptTab.entries.size
),
icon = {
if (selectedTab == tab) {
Icon(
imageVector = Icons.Default.Check,
contentDescription = null
)
}
},
label = {
Text(
when (tab) {
SystemPromptTab.PRESETS -> "Presets"
SystemPromptTab.CUSTOM -> "Custom"
SystemPromptTab.RECENTS -> "Recents"
}
)
}
)
}
}
}
@Composable
private fun SystemPromptTabContent(
selectedTab: SystemPromptTab,
presetPrompts: List<SystemPrompt>,
recentPrompts: List<SystemPrompt>,
customPromptText: String,
onCustomPromptChange: (String) -> Unit,
selectedPromptId: String?,
expandedPromptId: String?,
onPromptSelected: (SystemPrompt) -> Unit,
onExpandPrompt: (String) -> Unit
) {
when (selectedTab) {
SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> {
val prompts = if (selectedTab == SystemPromptTab.PRESETS) presetPrompts else recentPrompts
if (prompts.isEmpty()) {
Text(
text =
if (selectedTab == SystemPromptTab.PRESETS) "No preset prompts available."
else "No recent prompts found.",
style = MaterialTheme.typography.bodyMedium,
color = MaterialTheme.colorScheme.onSurfaceVariant,
modifier = Modifier.padding(16.dp)
)
} else {
PromptList(
prompts = prompts,
selectedPromptId = selectedPromptId,
expandedPromptId = expandedPromptId,
onPromptSelected = onPromptSelected,
onExpandPrompt = onExpandPrompt
)
}
}
SystemPromptTab.CUSTOM -> {
OutlinedTextField(
value = customPromptText,
onValueChange = onCustomPromptChange,
modifier = Modifier
.fillMaxWidth()
.fillMaxSize(),
label = { Text("Enter system prompt") },
placeholder = { Text("You are a helpful assistant...") },
minLines = 5
)
}
}
}
@OptIn(ExperimentalFoundationApi::class) @OptIn(ExperimentalFoundationApi::class)
@Composable @Composable
fun PromptList( private fun PromptList(
prompts: List<SystemPrompt>, prompts: List<SystemPrompt>,
selectedPromptId: String?, selectedPromptId: String?,
expandedPromptId: String?, expandedPromptId: String?,
@ -516,8 +537,3 @@ fun PromptList(
} }
} }
} }
enum class Mode {
BENCHMARK,
CONVERSATION
}

View File

@ -1,6 +1,5 @@
package com.example.llama.revamp.viewmodel package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.data.repository.SystemPromptRepository
@ -16,7 +15,7 @@ import javax.inject.Inject
class ModelLoadingViewModel @Inject constructor( class ModelLoadingViewModel @Inject constructor(
private val modelLoadingService: ModelLoadingService, private val modelLoadingService: ModelLoadingService,
private val repository: SystemPromptRepository private val repository: SystemPromptRepository
) : ViewModel() { ) : ModelUnloadingViewModel(modelLoadingService) {
/** /**
* Currently selected model to be loaded * Currently selected model to be loaded