From 4848bf93d075618bf9b3fda3d67d48ca7a7342fe Mon Sep 17 00:00:00 2001 From: Han Yin Date: Sat, 12 Apr 2025 12:28:39 -0700 Subject: [PATCH] data: introduce repo for System Prompt; flow data from Room to VM --- .../llama/revamp/data/model/SystemPrompt.kt | 89 ++++++++----- .../data/repository/SystemPromptRepository.kt | 108 ++++++++++++++++ .../revamp/ui/screens/ModeSelectionScreen.kt | 117 ++++++++++++------ .../revamp/util/ViewModelFactoryProvider.kt | 21 +++- .../revamp/viewmodel/SystemPromptViewModel.kt | 86 +++++++++++++ 5 files changed, 344 insertions(+), 77 deletions(-) create mode 100644 examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/SystemPromptRepository.kt create mode 100644 examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/SystemPromptViewModel.kt diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/SystemPrompt.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/SystemPrompt.kt index 91ceb7a3f2..7d89f8b51a 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/SystemPrompt.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/SystemPrompt.kt @@ -1,72 +1,93 @@ package com.example.llama.revamp.data.model +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale +import java.util.UUID + /** - * Data class representing a system prompt for LLM. + * Sealed class for system prompts with distinct types. */ -data class SystemPrompt( - val id: String, - val name: String, - val content: String, - val category: Category, - val lastUsed: Long? = null -) { - enum class Category { - STAFF_PICK, - USER_CREATED, - RECENT +sealed class SystemPrompt { + abstract val id: String + abstract val content: String + abstract val title: String + abstract val timestamp: Long? + + /** + * Preset system prompt from predefined collection. + */ + data class Preset( + override val id: String, + override val content: String, + val name: String, + override val timestamp: Long? = null + ) : SystemPrompt() { + override val title: String + get() = name + } + + /** + * Custom system prompt created by the user. + */ + data class Custom( + override val id: String = UUID.randomUUID().toString(), + override val content: String, + override val timestamp: Long = System.currentTimeMillis() + ) : SystemPrompt() { + override val title: String + get() = if (timestamp != null) { + val dateFormat = SimpleDateFormat("yyyy-MM-dd HH:mm", Locale.getDefault()) + dateFormat.format(Date(timestamp)) + } else { + "Custom Prompt" + } } companion object { /** - * Creates a list of sample system prompts for development and testing. + * Creates a list of sample presets. */ fun getStaffPickedPrompts(): List { return listOf( - SystemPrompt( + Preset( id = "assistant", name = "Helpful Assistant", - content = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should be informative and engaging. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.", - category = Category.STAFF_PICK + content = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should be informative and engaging. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information." ), - SystemPrompt( + Preset( id = "coder", name = "Coding Assistant", - content = "You are a helpful programming assistant. When asked coding questions, provide clear and functional code examples when applicable. If a question is ambiguous, ask for clarification. Focus on providing accurate solutions with good coding practices and explain your solutions.", - category = Category.STAFF_PICK + content = "You are a helpful programming assistant. When asked coding questions, provide clear and functional code examples when applicable. If a question is ambiguous, ask for clarification. Focus on providing accurate solutions with good coding practices and explain your solutions." ), - SystemPrompt( + Preset( id = "summarizer", name = "Text Summarizer", - content = "You are a helpful assistant that specializes in summarizing text. When provided with a text, create a concise summary that captures the main points, key details, and overall message. Adjust summary length based on original content length. Maintain factual accuracy and avoid adding information not present in the original text.", - category = Category.STAFF_PICK + content = "You are a helpful assistant that specializes in summarizing text. When provided with a text, create a concise summary that captures the main points, key details, and overall message. Adjust summary length based on original content length. Maintain factual accuracy and avoid adding information not present in the original text." ), - SystemPrompt( + Preset( id = "creative", name = "Creative Writer", - content = "You are a creative writing assistant with a vivid imagination. Help users draft stories, poems, scripts, and other creative content. Provide imaginative ideas while following the user's specifications. When responding, focus on being original, engaging, and matching the requested tone and style.", - category = Category.STAFF_PICK + content = "You are a creative writing assistant with a vivid imagination. Help users draft stories, poems, scripts, and other creative content. Provide imaginative ideas while following the user's specifications. When responding, focus on being original, engaging, and matching the requested tone and style." ) ) } /** - * Get recent system prompts (would normally be from storage) + * Creates a placeholder list of recent prompts. + * In a real implementation, this would be loaded from the database. */ fun getRecentPrompts(): List { return listOf( - SystemPrompt( + Custom( id = "custom-1", - name = "Technical Writer", content = "You are a technical documentation specialist. When responding, focus on clarity, precision, and structure. Use appropriate technical terminology based on the context, but avoid jargon when simpler terms would suffice. Include examples where helpful, and organize information in a logical manner.", - category = Category.USER_CREATED, - lastUsed = System.currentTimeMillis() - 3600000 // 1 hour ago + timestamp = System.currentTimeMillis() - 3600000 // 1 hour ago ), - SystemPrompt( + Custom( id = "custom-2", - name = "Science Educator", content = "You are a science educator with expertise in explaining complex concepts in accessible ways. Provide accurate, informative responses that help users understand scientific topics. Use analogies, examples, and clear explanations to make difficult concepts understandable. Cite established scientific consensus and explain levels of certainty when appropriate.", - category = Category.USER_CREATED, - lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago + timestamp = System.currentTimeMillis() - 86400000 // 1 day ago ) ) } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/SystemPromptRepository.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/SystemPromptRepository.kt new file mode 100644 index 0000000000..9529d65d04 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/repository/SystemPromptRepository.kt @@ -0,0 +1,108 @@ +package com.example.llama.revamp.data.repository + +import android.content.Context +import com.example.llama.revamp.data.local.AppDatabase +import com.example.llama.revamp.data.local.SystemPromptEntity +import com.example.llama.revamp.data.model.SystemPrompt +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.map +import java.util.UUID + +/** + * Repository for managing system prompts. + */ +class SystemPromptRepository(context: Context) { + + private val systemPromptDao = AppDatabase.getDatabase(context).systemPromptDao() + + // Maximum number of recent prompts to keep + private val MAX_RECENT_PROMPTS = 10 + + /** + * Get all preset prompts. + */ + fun getPresetPrompts(): Flow> { + // For now, we'll just return the static list since we don't store presets in the database + return kotlinx.coroutines.flow.flowOf(SystemPrompt.getStaffPickedPrompts()) + } + + /** + * Get recent prompts from the database. + */ + fun getRecentPrompts(): Flow> { + return systemPromptDao.getRecentPrompts(MAX_RECENT_PROMPTS) + .map { entities -> + entities.map { it.toDomainModel() } + } + } + + /** + * Save a prompt to the recents list. + * If it's already in recents, just update the timestamp. + */ + suspend fun savePromptToRecents(prompt: SystemPrompt) { + // Check if this prompt already exists + val existingPrompt = systemPromptDao.getPromptById(prompt.id) + + if (existingPrompt != null) { + // Update the timestamp to mark it as recently used + systemPromptDao.updatePromptTimestamp(prompt.id, System.currentTimeMillis()) + } else { + // Insert as a new prompt + systemPromptDao.insertPrompt(SystemPromptEntity.fromDomainModel(prompt)) + + // Check if we need to trim the list + pruneOldPrompts() + } + } + + /** + * Create and save a custom prompt. + */ + suspend fun saveCustomPrompt(content: String): SystemPrompt { + val customPrompt = SystemPrompt.Custom( + id = UUID.randomUUID().toString(), + content = content + ) + + systemPromptDao.insertPrompt(SystemPromptEntity.fromDomainModel(customPrompt)) + + // Check if we need to trim the list + pruneOldPrompts() + + return customPrompt + } + + /** + * Remove prompts if we exceed the maximum count. + */ + private suspend fun pruneOldPrompts() { + val count = systemPromptDao.getPromptCount() + if (count > MAX_RECENT_PROMPTS) { + // Get all prompts and delete the oldest ones + val allPrompts = systemPromptDao.getAllPrompts().first() + val promptsToDelete = allPrompts + .sortedByDescending { it.timestamp } + .drop(MAX_RECENT_PROMPTS) + + promptsToDelete.forEach { + systemPromptDao.deletePrompt(it) + } + } + } + + /** + * Delete a prompt by ID. + */ + suspend fun deletePrompt(id: String) { + systemPromptDao.deletePromptById(id) + } + + /** + * Delete all prompts. + */ + suspend fun deleteAllPrompts() { + systemPromptDao.deleteAllPrompts() + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModeSelectionScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModeSelectionScreen.kt index 8d14e7655b..09704c50bf 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModeSelectionScreen.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModeSelectionScreen.kt @@ -17,10 +17,7 @@ import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.width import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.items -import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.selection.selectable -import androidx.compose.foundation.selection.selectableGroup -import androidx.compose.foundation.verticalScroll import androidx.compose.material.icons.Icons import androidx.compose.material.icons.filled.Check import androidx.compose.material3.Button @@ -39,24 +36,26 @@ import androidx.compose.material3.SingleChoiceSegmentedButtonRow import androidx.compose.material3.Switch import androidx.compose.material3.Text import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState import androidx.compose.runtime.getValue import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.remember -import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.setValue import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.semantics.Role import androidx.compose.ui.text.style.TextOverflow import androidx.compose.ui.unit.dp +import androidx.lifecycle.viewmodel.compose.viewModel import com.example.llama.revamp.data.model.SystemPrompt +import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.navigation.NavigationActions import com.example.llama.revamp.ui.components.AppScaffold -import kotlinx.coroutines.launch -import java.text.SimpleDateFormat -import java.util.Date -import java.util.Locale +import com.example.llama.revamp.util.ViewModelFactoryProvider +import com.example.llama.revamp.viewmodel.SystemPromptViewModel enum class SystemPromptTab { PRESETS, CUSTOM, RECENTS @@ -72,17 +71,37 @@ fun ModeSelectionScreen( drawerState: DrawerState, navigationActions: NavigationActions ) { - val staffPickedPrompts = remember { SystemPrompt.getStaffPickedPrompts() } - val recentPrompts = remember { SystemPrompt.getRecentPrompts() } + // Set up SystemPromptViewModel + val context = LocalContext.current + val repository = remember { SystemPromptRepository(context) } + val factory = remember { ViewModelFactoryProvider.getSystemPromptViewModelFactory(repository) } + val viewModel: SystemPromptViewModel = viewModel(factory = factory) + + val presetPrompts by viewModel.presetPrompts.collectAsState() + val recentPrompts by viewModel.recentPrompts.collectAsState() var selectedMode by remember { mutableStateOf(null) } var useSystemPrompt by remember { mutableStateOf(false) } - var selectedPrompt by remember { mutableStateOf(staffPickedPrompts.firstOrNull()) } + var selectedPrompt by remember { mutableStateOf(null) } var selectedTab by remember { mutableStateOf(SystemPromptTab.PRESETS) } var customPromptText by remember { mutableStateOf("") } var expandedPromptId by remember { mutableStateOf(null) } - val coroutineScope = rememberCoroutineScope() + // Automatically select first preset and expand it + LaunchedEffect(presetPrompts) { + if (presetPrompts.isNotEmpty() && selectedPrompt == null) { + val firstPreset = presetPrompts.first() + selectedPrompt = firstPreset + expandedPromptId = firstPreset.id + } + } + + // Determine if a system prompt is actually selected/entered when the switch is on + val hasActiveSystemPrompt = when { + !useSystemPrompt -> true // Not using system prompt, so this is fine + selectedTab == SystemPromptTab.CUSTOM -> customPromptText.isNotBlank() + else -> selectedPrompt != null + } // Check if we're in a loading state val isLoading = engineState !is InferenceEngine.State.Uninitialized && @@ -134,10 +153,14 @@ fun ModeSelectionScreen( modifier = Modifier .fillMaxWidth() .padding(bottom = 8.dp) + // Only use weight if system prompt is active, otherwise wrap content + .then(if (useSystemPrompt) Modifier.weight(1f) else Modifier) ) { Column( modifier = Modifier .fillMaxWidth() + // Only fill height if system prompt is active + .then(if (useSystemPrompt) Modifier.fillMaxSize() else Modifier) ) { // Conversation option Row( @@ -199,6 +222,7 @@ fun ModeSelectionScreen( Column( modifier = Modifier .fillMaxWidth() + .fillMaxSize() // Fill remaining card space .padding(horizontal = 16.dp, vertical = 8.dp) ) { HorizontalDivider(modifier = Modifier.padding(vertical = 8.dp)) @@ -258,20 +282,29 @@ fun ModeSelectionScreen( // Content based on selected tab when (selectedTab) { SystemPromptTab.PRESETS -> { - PromptList( - prompts = staffPickedPrompts, - selectedPromptId = selectedPrompt?.id, - expandedPromptId = expandedPromptId, - onPromptSelected = { - selectedPrompt = it - expandedPromptId = it.id - }, - onExpandPrompt = { expandedPromptId = it } - ) + if (presetPrompts.isEmpty()) { + Text( + text = "No preset prompts available.", + style = MaterialTheme.typography.bodyMedium, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.padding(16.dp) + ) + } else { + PromptList( + prompts = presetPrompts, + selectedPromptId = selectedPrompt?.id, + expandedPromptId = expandedPromptId, + onPromptSelected = { + selectedPrompt = it + expandedPromptId = it.id + }, + onExpandPrompt = { expandedPromptId = it } + ) + } } SystemPromptTab.CUSTOM -> { - // Custom prompt editor + // Custom prompt editor (fill remaining space) OutlinedTextField( value = customPromptText, onValueChange = { @@ -283,11 +316,10 @@ fun ModeSelectionScreen( }, modifier = Modifier .fillMaxWidth() - .height(200.dp), + .fillMaxSize(), // Fill available space label = { Text("Enter system prompt") }, placeholder = { Text("You are a helpful assistant...") }, - minLines = 5, - maxLines = 10 + minLines = 5 ) } @@ -318,7 +350,10 @@ fun ModeSelectionScreen( } } - Spacer(modifier = Modifier.weight(1f)) + // Flexible spacer when system prompt is not active + if (!useSystemPrompt) { + Spacer(modifier = Modifier.weight(1f)) + } // Start button Button( @@ -329,9 +364,16 @@ fun ModeSelectionScreen( val systemPrompt = if (useSystemPrompt) { when (selectedTab) { SystemPromptTab.PRESETS, SystemPromptTab.RECENTS -> - selectedPrompt?.content + selectedPrompt?.let { prompt -> + // Save the prompt to recent prompts database + viewModel.savePromptToRecents(prompt) + prompt.content + } SystemPromptTab.CUSTOM -> - customPromptText.takeIf { it.isNotBlank() } + customPromptText.takeIf { it.isNotBlank() }?.also { promptText -> + // Save custom prompt to database + viewModel.saveCustomPromptToRecents(promptText) + } } } else null onConversationSelected(systemPrompt) @@ -342,7 +384,8 @@ fun ModeSelectionScreen( modifier = Modifier .fillMaxWidth() .height(56.dp), - enabled = selectedMode != null && !isLoading + enabled = selectedMode != null && !isLoading && + (!useSystemPrompt || hasActiveSystemPrompt) ) { if (isLoading) { CircularProgressIndicator( @@ -379,7 +422,7 @@ fun PromptList( LazyColumn( modifier = Modifier .fillMaxWidth() - .height(250.dp), + .fillMaxSize(), // Fill available space verticalArrangement = Arrangement.spacedBy(8.dp) ) { items( @@ -416,16 +459,8 @@ fun PromptList( .weight(1f) .padding(start = 8.dp) ) { - // Format title for recents if needed - val title = if (prompt.category == SystemPrompt.Category.USER_CREATED && prompt.lastUsed != null) { - val dateFormat = SimpleDateFormat("yyyy-MM-dd HH:mm", Locale.getDefault()) - dateFormat.format(Date(prompt.lastUsed)) - } else { - prompt.name - } - Text( - text = title, + text = prompt.title, style = MaterialTheme.typography.titleSmall, color = if (isSelected) MaterialTheme.colorScheme.primary @@ -443,7 +478,7 @@ fun PromptList( } } - if (prompt != prompts.last()) { + if (prompt.id != prompts.last().id) { HorizontalDivider( modifier = Modifier.padding(top = 8.dp, start = 40.dp) ) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt index 3b3083d855..2b7e72fe14 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/util/ViewModelFactoryProvider.kt @@ -3,15 +3,15 @@ package com.example.llama.revamp.util import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModelProvider import com.example.llama.revamp.data.preferences.UserPreferences +import com.example.llama.revamp.data.repository.SystemPromptRepository import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.monitoring.PerformanceMonitor import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel +import com.example.llama.revamp.viewmodel.SystemPromptViewModel /** * Utility class to provide ViewModel factories. - * - * TODO-han.yin: Replace with Hilt */ object ViewModelFactoryProvider { @@ -49,4 +49,21 @@ object ViewModelFactoryProvider { } } } + + /** + * Creates a factory for SystemPromptViewModel. + */ + fun getSystemPromptViewModelFactory( + repository: SystemPromptRepository + ): ViewModelProvider.Factory { + return object : ViewModelProvider.Factory { + @Suppress("UNCHECKED_CAST") + override fun create(modelClass: Class): T { + if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) { + return SystemPromptViewModel(repository) as T + } + throw IllegalArgumentException("Unknown ViewModel class: ${modelClass.name}") + } + } + } } diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/SystemPromptViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/SystemPromptViewModel.kt new file mode 100644 index 0000000000..d3ee2df380 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/SystemPromptViewModel.kt @@ -0,0 +1,86 @@ +package com.example.llama.revamp.viewmodel + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.viewModelScope +import com.example.llama.revamp.data.model.SystemPrompt +import com.example.llama.revamp.data.repository.SystemPromptRepository +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.stateIn +import kotlinx.coroutines.launch + +/** + * ViewModel for handling system prompts. + */ +class SystemPromptViewModel( + private val repository: SystemPromptRepository +) : ViewModel() { + + // Preset prompts + val presetPrompts: StateFlow> = repository.getPresetPrompts() + .stateIn( + scope = viewModelScope, + started = SharingStarted.WhileSubscribed(5000), + initialValue = emptyList() + ) + + // Recent prompts + val recentPrompts: StateFlow> = repository.getRecentPrompts() + .stateIn( + scope = viewModelScope, + started = SharingStarted.WhileSubscribed(5000), + initialValue = emptyList() + ) + + /** + * Save a prompt to the recents list. + */ + fun savePromptToRecents(prompt: SystemPrompt) { + viewModelScope.launch { + repository.savePromptToRecents(prompt) + } + } + + /** + * Create and save a custom prompt. + */ + fun saveCustomPromptToRecents(content: String) { + viewModelScope.launch { + repository.saveCustomPrompt(content) + } + } + + /** + * Delete a prompt by ID. + */ + fun deletePrompt(id: String) { + viewModelScope.launch { + repository.deletePrompt(id) + } + } + + /** + * Clear all recent prompts. + */ + fun clearRecentPrompts() { + viewModelScope.launch { + repository.deleteAllPrompts() + } + } + + /** + * Factory for creating SystemPromptViewModel instances. + */ + class Factory( + private val repository: SystemPromptRepository + ) : ViewModelProvider.Factory { + @Suppress("UNCHECKED_CAST") + override fun create(modelClass: Class): T { + if (modelClass.isAssignableFrom(SystemPromptViewModel::class.java)) { + return SystemPromptViewModel(repository) as T + } + throw IllegalArgumentException("Unknown ViewModel class") + } + } +}