diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml index ffbd88a081..37b5c94626 100644 --- a/examples/llama.android/app/src/main/AndroidManifest.xml +++ b/examples/llama.android/app/src/main/AndroidManifest.xml @@ -20,6 +20,17 @@ android:exported="true" android:theme="@style/Theme.LlamaAndroid"> + + + + + + + + diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt new file mode 100644 index 0000000000..dd0d6231c8 --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt @@ -0,0 +1,229 @@ +package com.example.llama.revamp + +import android.os.Bundle +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.material3.DrawerValue +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Surface +import androidx.compose.material3.rememberDrawerState +import androidx.compose.runtime.Composable +import androidx.compose.runtime.LaunchedEffect +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.rememberCoroutineScope +import androidx.compose.runtime.setValue +import androidx.compose.ui.Modifier +import androidx.lifecycle.viewmodel.compose.viewModel +import androidx.navigation.NavType +import androidx.navigation.compose.NavHost +import androidx.navigation.compose.composable +import androidx.navigation.compose.rememberNavController +import androidx.navigation.navArgument +import com.example.llama.revamp.engine.InferenceEngine +import com.example.llama.revamp.navigation.AppDestinations +import com.example.llama.revamp.navigation.NavigationActions +import com.example.llama.revamp.ui.components.UnloadModelConfirmationDialog +import com.example.llama.revamp.ui.screens.BenchmarkScreen +import com.example.llama.revamp.ui.screens.ConversationScreen +import com.example.llama.revamp.ui.screens.ModelSelectionScreen +import com.example.llama.revamp.ui.screens.ModeSelectionScreen +import com.example.llama.revamp.ui.screens.SettingsScreen +import com.example.llama.revamp.ui.screens.SettingsTab +import com.example.llama.revamp.ui.theme.LlamaTheme +import com.example.llama.revamp.util.ViewModelFactoryProvider +import com.example.llama.revamp.viewmodel.MainViewModel +import kotlinx.coroutines.launch + +class MainActivity : ComponentActivity() { + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + setContent { + LlamaTheme { + Surface( + modifier = Modifier.fillMaxSize(), + color = MaterialTheme.colorScheme.background + ) { + AppContent() + } + } + } + } +} + +@Composable +fun AppContent() { + val navController = rememberNavController() + val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed) + val coroutineScope = rememberCoroutineScope() + + // Create inference engine + val inferenceEngine = remember { InferenceEngine() } + + // Create factory for MainViewModel + val factory = remember { ViewModelFactoryProvider.getMainViewModelFactory(inferenceEngine) } + + // Get ViewModel instance with factory + val viewModel: MainViewModel = viewModel(factory = factory) + + val engineState by viewModel.engineState.collectAsState() + + val navigationActions = remember(navController) { + NavigationActions(navController) + } + + // Model unloading confirmation + var showUnloadDialog by remember { mutableStateOf(false) } + var pendingNavigation by remember { mutableStateOf<(() -> Unit)?>(null) } + + // Observe back button + LaunchedEffect(navController) { + navController.addOnDestinationChangedListener { _, destination, _ -> + // Log navigation for debugging + println("Navigation: ${destination.route}") + } + } + + // Handle drawer state + val openDrawer: () -> Unit = { + coroutineScope.launch { + drawerState.open() + } + } + + // Main Content + NavHost( + navController = navController, + startDestination = AppDestinations.MODEL_SELECTION_ROUTE + ) { + // Model Selection Screen + composable(AppDestinations.MODEL_SELECTION_ROUTE) { + ModelSelectionScreen( + onModelSelected = { modelInfo -> + viewModel.selectModel(modelInfo) + navigationActions.navigateToModeSelection() + }, + onManageModelsClicked = { + navigationActions.navigateToSettings(SettingsTab.MODEL_MANAGEMENT.name) + }, + onMenuClicked = openDrawer, + drawerState = drawerState, + navigationActions = navigationActions + ) + } + + // Mode Selection Screen + composable(AppDestinations.MODE_SELECTION_ROUTE) { + ModeSelectionScreen( + engineState = engineState, + onBenchmarkSelected = { + viewModel.prepareForBenchmark() + navigationActions.navigateToBenchmark() + }, + onConversationSelected = { systemPrompt -> + viewModel.prepareForConversation(systemPrompt) + navigationActions.navigateToConversation() + }, + onBackPressed = { + // Need to unload model before going back + if (viewModel.isModelLoaded()) { + showUnloadDialog = true + pendingNavigation = { navController.popBackStack() } + } else { + navController.popBackStack() + } + }, + drawerState = drawerState, + navigationActions = navigationActions + ) + } + + // Conversation Screen + composable(AppDestinations.CONVERSATION_ROUTE) { + ConversationScreen( + onBackPressed = { + // Need to unload model before going back + if (viewModel.isModelLoaded()) { + showUnloadDialog = true + pendingNavigation = { navController.popBackStack() } + } else { + navController.popBackStack() + } + }, + drawerState = drawerState, + navigationActions = navigationActions, + viewModel = viewModel + ) + } + + // Benchmark Screen + composable(AppDestinations.BENCHMARK_ROUTE) { + BenchmarkScreen( + onBackPressed = { + // Need to unload model before going back + if (viewModel.isModelLoaded()) { + showUnloadDialog = true + pendingNavigation = { navController.popBackStack() } + } else { + navController.popBackStack() + } + }, + onRerunPressed = { + viewModel.rerunBenchmark() + }, + onSharePressed = { + // Stub for sharing functionality + }, + drawerState = drawerState, + navigationActions = navigationActions, + viewModel = viewModel + ) + } + + // Settings Screen + composable( + route = "${AppDestinations.SETTINGS_ROUTE}/{tab}", + arguments = listOf( + navArgument("tab") { + type = NavType.StringType + defaultValue = SettingsTab.GENERAL.name + } + ) + ) { backStackEntry -> + val tabName = backStackEntry.arguments?.getString("tab") ?: SettingsTab.GENERAL.name + val tab = try { + SettingsTab.valueOf(tabName) + } catch (e: IllegalArgumentException) { + SettingsTab.GENERAL + } + + SettingsScreen( + selectedTab = tab, + onBackPressed = { navController.popBackStack() }, + drawerState = drawerState, + navigationActions = navigationActions + ) + } + } + + // Model unload confirmation dialog + if (showUnloadDialog) { + UnloadModelConfirmationDialog( + onConfirm = { + showUnloadDialog = false + coroutineScope.launch { + viewModel.unloadModel() + pendingNavigation?.invoke() + pendingNavigation = null + } + }, + onDismiss = { + showUnloadDialog = false + pendingNavigation = null + } + ) + } +} diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt new file mode 100644 index 0000000000..05439d81af --- /dev/null +++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt @@ -0,0 +1,205 @@ +package com.example.llama.revamp.viewmodel + +import androidx.lifecycle.ViewModel +import androidx.lifecycle.ViewModelProvider +import androidx.lifecycle.viewModelScope +import com.example.llama.revamp.data.model.ModelInfo +import com.example.llama.revamp.engine.InferenceEngine +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.launch +import java.text.SimpleDateFormat +import java.util.Date +import java.util.Locale + +/** + * Main ViewModel that handles the LLM engine state and operations. + */ +class MainViewModel( + private val inferenceEngine: InferenceEngine = InferenceEngine() +) : ViewModel() { + + // Expose the engine state + val engineState: StateFlow = inferenceEngine.state + + // Benchmark results + val benchmarkResults: StateFlow = inferenceEngine.benchmarkResults + + // Selected model information + private val _selectedModel = MutableStateFlow(null) + val selectedModel: StateFlow = _selectedModel.asStateFlow() + + // Benchmark parameters + private var pp: Int = 32 + private var tg: Int = 32 + private var pl: Int = 512 + + // Messages in the conversation + private val _messages = MutableStateFlow>(emptyList()) + val messages: StateFlow> = _messages.asStateFlow() + + // System prompt for the conversation + private val _systemPrompt = MutableStateFlow(null) + val systemPrompt: StateFlow = _systemPrompt.asStateFlow() + + /** + * Selects a model for use. + */ + fun selectModel(modelInfo: ModelInfo) { + _selectedModel.value = modelInfo + } + + /** + * Prepares the engine for benchmark mode. + */ + fun prepareForBenchmark() { + viewModelScope.launch { + _selectedModel.value?.let { model -> + inferenceEngine.loadModel(model.path) + runBenchmark() + } + } + } + + /** + * Runs the benchmark with current parameters. + */ + private suspend fun runBenchmark() { + inferenceEngine.bench(pp, tg, pl) + } + + /** + * Reruns the benchmark. + */ + fun rerunBenchmark() { + viewModelScope.launch { + runBenchmark() + } + } + + /** + * Prepares the engine for conversation mode. + */ + fun prepareForConversation(systemPrompt: String? = null) { + _systemPrompt.value = systemPrompt + viewModelScope.launch { + _selectedModel.value?.let { model -> + inferenceEngine.loadModel(model.path, systemPrompt) + } + } + } + + /** + * Sends a user message and collects the response. + */ + fun sendMessage(content: String) { + if (content.isBlank()) return + + // Add user message + val userMessage = Message.User( + content = content, + timestamp = System.currentTimeMillis() + ) + _messages.value = _messages.value + userMessage + + // Create placeholder for assistant message + val assistantMessage = Message.Assistant( + content = "", + timestamp = System.currentTimeMillis(), + isComplete = false + ) + _messages.value = _messages.value + assistantMessage + + // Get response from engine + val messageIndex = _messages.value.size - 1 + + viewModelScope.launch { + val response = StringBuilder() + + inferenceEngine.sendUserPrompt(content).collect { token -> + response.append(token) + + // Update the assistant message with the generated text + val currentMessages = _messages.value.toMutableList() + val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant + currentMessages[messageIndex] = currentAssistantMessage.copy( + content = response.toString(), + isComplete = false + ) + _messages.value = currentMessages + } + + // Mark message as complete when generation finishes + val finalMessages = _messages.value.toMutableList() + val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant + finalMessages[messageIndex] = finalAssistantMessage.copy( + isComplete = true + ) + _messages.value = finalMessages + } + } + + /** + * Unloads the currently loaded model. + */ + suspend fun unloadModel() { + inferenceEngine.unloadModel() + _messages.value = emptyList() + } + + /** + * Checks if a model is currently loaded. + */ + fun isModelLoaded(): Boolean { + return engineState.value !is InferenceEngine.State.Uninitialized && + engineState.value !is InferenceEngine.State.LibraryLoaded + } + + /** + * Clean up resources when ViewModel is cleared. + */ + override fun onCleared() { + inferenceEngine.destroy() + super.onCleared() + } + + /** + * Factory for creating MainViewModel instances. + */ + class Factory(private val inferenceEngine: InferenceEngine) : ViewModelProvider.Factory { + @Suppress("UNCHECKED_CAST") + override fun create(modelClass: Class): T { + if (modelClass.isAssignableFrom(MainViewModel::class.java)) { + return MainViewModel(inferenceEngine) as T + } + throw IllegalArgumentException("Unknown ViewModel class") + } + } +} + +/** + * Sealed class representing messages in a conversation. + */ +sealed class Message { + abstract val content: String + abstract val timestamp: Long + + val formattedTime: String + get() { + val formatter = SimpleDateFormat("h:mm a", Locale.getDefault()) + return formatter.format(Date(timestamp)) + } + + data class User( + override val content: String, + override val timestamp: Long + ) : Message() + + data class Assistant( + override val content: String, + override val timestamp: Long, + val isComplete: Boolean = true + ) : Message() +}