diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
index a41ad919e0..03f57348ed 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/MainActivity.kt
@@ -170,7 +170,6 @@ fun AppContent(
         composable(AppDestinations.MODEL_SELECTION_ROUTE) {
             ModelSelectionScreen(
                 onModelSelected = { modelInfo ->
-                    mainVewModel.selectModel(modelInfo)
                     navigationActions.navigateToModelLoading()
                 },
                 onManageModelsClicked = {
@@ -184,34 +183,28 @@ fun AppContent(
         composable(AppDestinations.MODEL_LOADING_ROUTE) {
             ModelLoadingScreen(
                 engineState = engineState,
-                onBenchmarkSelected = {
-                    // Store a reference to the loading job
+                onBenchmarkSelected = { prepareJob ->
+                    // Wait for preparation to complete, then navigate if still active
                     val loadingJob = coroutineScope.launch {
-                        mainVewModel.prepareForBenchmark()
-                        // Check if the job wasn't cancelled before navigating
-                        if (isActive) {
-                            navigationActions.navigateToBenchmark()
-                        }
+                        prepareJob.join()
+                        if (isActive) { navigationActions.navigateToBenchmark() }
                     }
-
-                    // Update the pendingNavigation handler to cancel any ongoing loading
                     pendingNavigation = {
+                        prepareJob.cancel()
                         loadingJob.cancel()
                         navigationActions.navigateUp()
                     }
                 },
 
-                onConversationSelected = { systemPrompt ->
-                    // Store a reference to the loading job
+                onConversationSelected = { systemPrompt, prepareJob ->
+                    // Wait for preparation to complete, then navigate if still active
                     val loadingJob = coroutineScope.launch {
-                        mainVewModel.prepareForConversation(systemPrompt)
-                        // Check if the job wasn't cancelled before navigating
-                        if (isActive) {
-                            navigationActions.navigateToConversation()
-                        }
+                        prepareJob.join()
+                        if (isActive) { navigationActions.navigateToConversation() }
                     }
-                    // Update the pendingNavigation handler to cancel any ongoing loading
+
                     pendingNavigation = {
+                        prepareJob.cancel()
                         loadingJob.cancel()
                         navigationActions.navigateUp()
                     }
@@ -229,8 +222,7 @@ fun AppContent(
                 onBackPressed = {
                     // Need to unload model before going back
                     handleBackWithModelCheck()
-                },
-                viewModel = mainVewModel
+                }
             )
         }
 
@@ -240,14 +232,7 @@
                 onBackPressed = {
                     // Need to unload model before going back
                     handleBackWithModelCheck()
-                },
-                onRerunPressed = {
-                    mainVewModel.rerunBenchmark()
-                },
-                onSharePressed = {
-                    // Stub for sharing functionality
-                },
-                viewModel = mainVewModel
+                }
            )
         }
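Note on the two callbacks above: ModelLoadingScreen now owns the preparation coroutine and hands the resulting Job up to the navigation layer, which only joins it and then navigates. A minimal sketch of that handoff outside Compose (the delay and the printed "navigate" are stand-ins, not code from this patch):

    import kotlinx.coroutines.*

    fun main() = runBlocking {
        // Callee side: start preparation and expose it as a Job.
        val prepareJob = launch { delay(500) /* stand-in for model loading */ }

        // Caller side: wait for preparation, then proceed only if still active.
        val loadingJob = launch {
            prepareJob.join()
            if (isActive) println("navigate")   // stand-in for navigateToBenchmark()
        }

        // A back press would run the pendingNavigation handler, cancelling both:
        // prepareJob.cancel(); loadingJob.cancel()
        loadingJob.join()
    }
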
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
index 2d1b0c9cb8..42c46a59cc 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/BenchmarkScreen.kt
@@ -26,14 +26,13 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
 import com.example.llama.revamp.ui.theme.MonospacedTextStyle
+import com.example.llama.revamp.viewmodel.BenchmarkViewModel
 import com.example.llama.revamp.viewmodel.MainViewModel
 
 @Composable
 fun BenchmarkScreen(
     onBackPressed: () -> Unit,
-    onRerunPressed: () -> Unit,
-    onSharePressed: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel()
+    viewModel: BenchmarkViewModel = hiltViewModel()
 ) {
     val engineState by viewModel.engineState.collectAsState()
     val benchmarkResults by viewModel.benchmarkResults.collectAsState()
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
index cf97827402..3323804160 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ConversationScreen.kt
@@ -62,7 +62,7 @@ import androidx.lifecycle.LifecycleEventObserver
 import androidx.lifecycle.compose.LocalLifecycleOwner
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
-import com.example.llama.revamp.viewmodel.MainViewModel
+import com.example.llama.revamp.viewmodel.ConversationViewModel
 import com.example.llama.revamp.viewmodel.Message
 import kotlinx.coroutines.launch
 
@@ -72,7 +72,7 @@ import kotlinx.coroutines.launch
 @Composable
 fun ConversationScreen(
     onBackPressed: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel()
+    viewModel: ConversationViewModel = hiltViewModel()
 ) {
     val engineState by viewModel.engineState.collectAsState()
     val messages by viewModel.messages.collectAsState()
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
index 1033b4621e..f77194e143 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
@@ -40,6 +40,7 @@ import androidx.compose.runtime.collectAsState
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.remember
+import androidx.compose.runtime.rememberCoroutineScope
 import androidx.compose.runtime.setValue
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
@@ -50,7 +51,10 @@ import androidx.hilt.navigation.compose.hiltViewModel
 import com.example.llama.revamp.data.model.SystemPrompt
 import com.example.llama.revamp.engine.InferenceEngine
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
+import com.example.llama.revamp.viewmodel.ModelLoadingViewModel
 import com.example.llama.revamp.viewmodel.SystemPromptViewModel
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.launch
 
 enum class SystemPromptTab {
     PRESETS, CUSTOM, RECENTS
 }
@@ -59,14 +63,17 @@
 @OptIn(ExperimentalMaterial3Api::class, ExperimentalFoundationApi::class)
 @Composable
 fun ModelLoadingScreen(
-    viewModel: SystemPromptViewModel = hiltViewModel(),
     engineState: InferenceEngine.State,
-    onBenchmarkSelected: () -> Unit,
-    onConversationSelected: (String?) -> Unit,
+    onBenchmarkSelected: (prepareJob: Job) -> Unit,
+    onConversationSelected: (systemPrompt: String?, prepareJob: Job) -> Unit,
     onBackPressed: () -> Unit,
+    modelLoadingViewModel: ModelLoadingViewModel = hiltViewModel(),
+    systemPromptViewModel: SystemPromptViewModel = hiltViewModel(),
 ) {
-    val presetPrompts by viewModel.presetPrompts.collectAsState()
-    val recentPrompts by viewModel.recentPrompts.collectAsState()
+    val coroutineScope = rememberCoroutineScope()
+
+    val presetPrompts by systemPromptViewModel.presetPrompts.collectAsState()
+    val recentPrompts by systemPromptViewModel.recentPrompts.collectAsState()
 
     var selectedMode by remember { mutableStateOf<Mode?>(null) }
     var useSystemPrompt by remember { mutableStateOf(false) }
@@ -96,6 +103,21 @@
             engineState !is InferenceEngine.State.LibraryLoaded &&
             engineState !is InferenceEngine.State.AwaitingUserPrompt
 
+    // Mode selection callbacks
+    val handleBenchmarkSelected = {
+        val prepareJob = coroutineScope.launch {
+            modelLoadingViewModel.prepareForBenchmark()
+        }
+        onBenchmarkSelected(prepareJob)
+    }
+
+    val handleConversationSelected = { systemPrompt: String? ->
+        val prepareJob = coroutineScope.launch {
+            modelLoadingViewModel.prepareForConversation(systemPrompt)
+        }
+        onConversationSelected(systemPrompt, prepareJob)
+    }
+
     PerformanceAppScaffold(
         title = "Load Model",
         onNavigateBack = onBackPressed,
@@ -143,7 +165,7 @@
                 modifier = Modifier
                     .fillMaxWidth()
                     .padding(bottom = 4.dp)
-                    // Only use weight if system prompt is active, otherwise wrap content
+                    // Only fill height if system prompt is active
                     .then(if (useSystemPrompt) Modifier.weight(1f) else Modifier)
             ) {
                 Column(
@@ -355,14 +377,15 @@
 
                     Button(
                         onClick = {
                             when (selectedMode) {
-                                Mode.BENCHMARK -> onBenchmarkSelected()
+                                Mode.BENCHMARK -> handleBenchmarkSelected()
+
                                 Mode.CONVERSATION -> {
                                     val systemPrompt = if (useSystemPrompt) {
                                         when (selectedTab) {
                                             SystemPromptTab.PRESETS, SystemPromptTab.RECENTS ->
                                                 selectedPrompt?.let { prompt ->
                                                     // Save the prompt to recent prompts database
-                                                    viewModel.savePromptToRecents(prompt)
+                                                    systemPromptViewModel.savePromptToRecents(prompt)
                                                     prompt.content
                                                 }
@@ -370,15 +393,15 @@
                                                 customPromptText.takeIf { it.isNotBlank() }
                                                     ?.also { promptText ->
                                                         // Save custom prompt to database
-                                                        viewModel.saveCustomPromptToRecents(promptText)
+                                                        systemPromptViewModel.saveCustomPromptToRecents(promptText)
                                                     }
                                         }
                                     } else null
-                                    onConversationSelected(systemPrompt)
+
+                                    handleConversationSelected(systemPrompt)
                                 }
 
-                                null -> { /* No mode selected */
-                                }
+                                null -> { /* No mode selected */ }
                             }
                         },
                         modifier = Modifier
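ModelLoadingScreen above now pulls two independent ViewModels from hiltViewModel(), yet every screen in this patch still observes the same engine state and selected model. That only works if InferenceManager is a single shared instance. Its definition is not part of this diff; a hedged sketch of the Hilt wiring the design appears to assume (the constructor parameter is hypothetical, only the annotation placement is the point):

    import com.example.llama.revamp.engine.InferenceEngine
    import javax.inject.Inject
    import javax.inject.Singleton

    // Assumed: one InferenceManager per process, so BenchmarkViewModel,
    // ConversationViewModel, ModelLoadingViewModel and ModelSelectionViewModel
    // all see the same engineState/currentModel flows.
    @Singleton
    class InferenceManager @Inject constructor(
        private val inferenceEngine: InferenceEngine,  // hypothetical dependency
    ) {
        // engineState, currentModel, systemPrompt, ... as used throughout the diff
    }
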
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt
index fe56d576bd..c64f62bbda 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelSelectionScreen.kt
@@ -30,7 +30,7 @@ import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.ui.components.ModelCard
 import com.example.llama.revamp.ui.components.ModelCardActions
 import com.example.llama.revamp.ui.components.PerformanceAppScaffold
-import com.example.llama.revamp.viewmodel.MainViewModel
+import com.example.llama.revamp.viewmodel.ModelSelectionViewModel
 
 @OptIn(ExperimentalMaterial3Api::class)
 @Composable
@@ -38,10 +38,15 @@ fun ModelSelectionScreen(
     onModelSelected: (ModelInfo) -> Unit,
     onManageModelsClicked: () -> Unit,
     onMenuClicked: () -> Unit,
-    viewModel: MainViewModel = hiltViewModel(),
+    viewModel: ModelSelectionViewModel = hiltViewModel(),
 ) {
     val models by viewModel.availableModels.collectAsState()
 
+    val handleModelSelection = { model: ModelInfo ->
+        viewModel.selectModel(model)
+        onModelSelected(model)
+    }
+
     PerformanceAppScaffold(
         title = "Models",
         onMenuOpen = onMenuClicked,
@@ -60,11 +65,13 @@
         items(models) { model ->
             ModelCard(
                 model = model,
-                onClick = { onModelSelected(model) },
+                onClick = { handleModelSelection(model) },
                 modifier = Modifier.padding(vertical = 4.dp),
                 isSelected = null, // Not in selection mode
                 actionButton = {
-                    ModelCardActions.PlayButton(onClick = { onModelSelected(model) })
+                    ModelCardActions.PlayButton {
+                        handleModelSelection(model)
+                    }
                 }
             )
             Spacer(modifier = Modifier.height(8.dp))
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
new file mode 100644
index 0000000000..179d5eb159
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
@@ -0,0 +1,29 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.launch
+import javax.inject.Inject
+
+@HiltViewModel
+class BenchmarkViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
+    val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
+    val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
+
+    /**
+     * Run benchmark with specified parameters
+     */
+    fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
+        viewModelScope.launch {
+            inferenceManager.benchmark(pp, tg, pl, nr)
+        }
+}
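BenchmarkViewModel is the first consumer whose explicit StateFlow declarations pin down types on the InferenceManager surface. Reading the call sites across the whole patch gives roughly the following assumed surface; the element type of benchmarkResults and the Pair emitted by generateResponse are inferences, not confirmed by the diff:

    import com.example.llama.revamp.data.model.ModelInfo
    import com.example.llama.revamp.engine.InferenceEngine
    import com.example.llama.revamp.engine.TokenMetrics
    import kotlinx.coroutines.flow.Flow
    import kotlinx.coroutines.flow.StateFlow

    // Hypothetical summary of what the new ViewModels call, for orientation only.
    interface AssumedInferenceManagerSurface {
        val engineState: StateFlow<InferenceEngine.State>
        val currentModel: StateFlow<ModelInfo?>
        val systemPrompt: StateFlow<String?>
        val benchmarkResults: StateFlow<String?>        // element type inferred

        fun setCurrentModel(modelInfo: ModelInfo)
        suspend fun loadModelForBenchmark()
        suspend fun loadModelForConversation(systemPrompt: String?)
        suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int)
        suspend fun unloadModel()

        // ConversationViewModel destructures (text, isComplete) in its collect.
        fun generateResponse(prompt: String): Flow<Pair<String, Boolean>>
        fun createTokenMetrics(): TokenMetrics
    }
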
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt
new file mode 100644
index 0000000000..fd8e243f1d
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt
@@ -0,0 +1,166 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.TokenMetrics
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.launch
+import java.text.SimpleDateFormat
+import java.util.Date
+import java.util.Locale
+import javax.inject.Inject
+import kotlin.getValue
+
+
+@HiltViewModel
+class ConversationViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState = inferenceManager.engineState
+    val selectedModel = inferenceManager.currentModel
+    val systemPrompt = inferenceManager.systemPrompt
+
+    // Messages in conversation
+    private val _messages = MutableStateFlow<List<Message>>(emptyList())
+    val messages: StateFlow<List<Message>> = _messages.asStateFlow()
+
+    // Token generation job
+    private var tokenCollectionJob: Job? = null
+
+    /**
+     * Send a message with the provided content.
+     * Note: This matches the existing UI which manages input state outside the ViewModel.
+     */
+    fun sendMessage(content: String) {
+        if (content.isBlank()) return
+
+        // Cancel ongoing collection
+        tokenCollectionJob?.cancel()
+
+        // Add user message
+        val userMessage = Message.User(
+            content = content,
+            timestamp = System.currentTimeMillis()
+        )
+        _messages.value = _messages.value + userMessage
+
+        // Add placeholder for assistant response
+        val assistantMessage = Message.Assistant.Ongoing(
+            content = "",
+            timestamp = System.currentTimeMillis()
+        )
+        _messages.value = _messages.value + assistantMessage
+
+        // Collect response
+        tokenCollectionJob = viewModelScope.launch {
+            try {
+                inferenceManager.generateResponse(content)
+                    .collect { (text, isComplete) ->
+                        updateAssistantMessage(text, isComplete)
+                    }
+            } catch (e: Exception) {
+                // Handle error
+                handleResponseError(e)
+            }
+        }
+    }
+
+    /**
+     * Handle updating the assistant message
+     */
+    private fun updateAssistantMessage(text: String, isComplete: Boolean) {
+        val currentMessages = _messages.value.toMutableList()
+        val lastIndex = currentMessages.size - 1
+        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
+
+        if (currentAssistantMessage != null) {
+            if (isComplete) {
+                // Final message with metrics
+                currentMessages[lastIndex] = Message.Assistant.Completed(
+                    content = text,
+                    timestamp = currentAssistantMessage.timestamp,
+                    metrics = inferenceManager.createTokenMetrics()
+                )
+            } else {
+                // Ongoing message update
+                currentMessages[lastIndex] = Message.Assistant.Ongoing(
+                    content = text,
+                    timestamp = currentAssistantMessage.timestamp
+                )
+            }
+            _messages.value = currentMessages
+        }
+    }
+
+    /**
+     * Handle response error
+     */
+    private fun handleResponseError(e: Exception) {
+        val currentMessages = _messages.value.toMutableList()
+        val lastIndex = currentMessages.size - 1
+        val currentAssistantMessage = currentMessages.getOrNull(lastIndex) as? Message.Assistant.Ongoing
+
+        if (currentAssistantMessage != null) {
+            currentMessages[lastIndex] = Message.Assistant.Completed(
+                content = "${currentAssistantMessage.content}[Error: ${e.message}]",
+                timestamp = currentAssistantMessage.timestamp,
+                metrics = inferenceManager.createTokenMetrics()
+            )
+            _messages.value = currentMessages
+        }
+    }
+
+    /**
+     * Clear conversation
+     */
+    fun clearConversation() {
+        tokenCollectionJob?.cancel()
+        _messages.value = emptyList()
+    }
+
+    override fun onCleared() {
+        tokenCollectionJob?.cancel()
+        super.onCleared()
+    }
+}
+
+
+/**
+ * Sealed class representing messages in a conversation.
+ */
+sealed class Message {
+    abstract val timestamp: Long
+    abstract val content: String
+
+    val formattedTime: String
+        get() = datetimeFormatter.format(Date(timestamp))
+
+    data class User(
+        override val timestamp: Long,
+        override val content: String
+    ) : Message()
+
+    sealed class Assistant : Message() {
+        data class Ongoing(
+            override val timestamp: Long,
+            override val content: String,
+        ) : Assistant()
+
+        data class Completed(
+            override val timestamp: Long,
+            override val content: String,
+            val metrics: TokenMetrics
+        ) : Assistant()
+    }
+
+    companion object {
+        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
+    }
+}
+
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
index fee0d6250a..ae709896d7 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
@@ -1,276 +1,25 @@
 package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
-import androidx.lifecycle.viewModelScope
-import com.example.llama.revamp.data.model.ModelInfo
-import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceManager
 import dagger.hilt.android.lifecycle.HiltViewModel
-import kotlinx.coroutines.Job
-import kotlinx.coroutines.flow.MutableStateFlow
-import kotlinx.coroutines.flow.SharingStarted
-import kotlinx.coroutines.flow.StateFlow
-import kotlinx.coroutines.flow.asStateFlow
-import kotlinx.coroutines.flow.catch
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.flow.stateIn
-import kotlinx.coroutines.launch
-import java.text.SimpleDateFormat
-import java.util.Date
-import java.util.Locale
 import javax.inject.Inject
 
-/**
- * Main ViewModel that handles the LLM engine state and operations.
- */
 @HiltViewModel
+/**
+ * Main ViewModel that exposes the core states of [InferenceEngine]
+ */
 class MainViewModel @Inject constructor (
-    private val inferenceEngine: InferenceEngine,
-    private val modelRepository: ModelRepository,
+    private val inferenceManager: InferenceManager,
 ) : ViewModel() {
 
-    // Expose the engine state
-    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
-
-    // Benchmark results
-    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
-
-    // Available models for selection
-    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
-        .stateIn(
-            scope = viewModelScope,
-            started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
-            initialValue = emptyList()
-        )
-
-    // Selected model information
-    private val _selectedModel = MutableStateFlow<ModelInfo?>(null)
-    val selectedModel: StateFlow<ModelInfo?> = _selectedModel.asStateFlow()
-
-    // Messages in the conversation
-    private val _messages = MutableStateFlow<List<Message>>(emptyList())
-    val messages: StateFlow<List<Message>> = _messages.asStateFlow()
-
-    // System prompt for the conversation
-    private val _systemPrompt = MutableStateFlow<String?>(null)
-    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
-
-    // Flag to track if token collection is active
-    private var tokenCollectionJob: Job? = null
+    val engineState = inferenceManager.engineState
 
     /**
-     * Selects a model for use.
+     * Unload the current model and release the resources
      */
-    fun selectModel(modelInfo: ModelInfo) {
-        _selectedModel.value = modelInfo
-
-        viewModelScope.launch {
-            modelRepository.updateModelLastUsed(modelInfo.id)
-        }
-    }
-
-    /**
-     * Prepares the engine for benchmark mode.
-     */
-    suspend fun prepareForBenchmark() {
-        _selectedModel.value?.let { model ->
-            inferenceEngine.loadModel(model.path)
-        }
-    }
-
-    /**
-     * Runs the benchmark with current parameters.
-     */
-    suspend fun runBenchmark() = inferenceEngine.bench(512, 128, 1, 3)
-
-    /**
-     * Reruns the benchmark.
-     */
-    fun rerunBenchmark() = viewModelScope.launch { runBenchmark() }
-
-    /**
-     * Prepares the engine for conversation mode.
-     */
-    suspend fun prepareForConversation(systemPrompt: String? = null) {
-        _systemPrompt.value = systemPrompt
-        _selectedModel.value?.let { model ->
-            inferenceEngine.loadModel(model.path, systemPrompt)
-        }
-    }
-
-    /**
-     * Tracks token generation metrics
-     */
-    private var generationStartTime: Long = 0L
-    private var firstTokenTime: Long = 0L
-    private var tokenCount: Int = 0
-    private var isFirstToken: Boolean = true
-
-    /**
-     * Sends a user message and collects the response.
-     */
-    fun sendMessage(content: String) {
-        if (content.isBlank()) return
-
-        // Cancel any ongoing token collection
-        tokenCollectionJob?.cancel()
-
-        // Add user message
-        val userMessage = Message.User(
-            content = content,
-            timestamp = System.currentTimeMillis()
-        )
-        _messages.value = _messages.value + userMessage
-
-        // Create placeholder for assistant message
-        val assistantMessage = Message.Assistant.Ongoing(
-            content = "",
-            timestamp = System.currentTimeMillis()
-        )
-        _messages.value = _messages.value + assistantMessage
-
-        // Reset metrics tracking
-        generationStartTime = System.currentTimeMillis()
-        firstTokenTime = 0L
-        tokenCount = 0
-        isFirstToken = true
-
-        // Get response from engine
-        tokenCollectionJob = viewModelScope.launch {
-            val response = StringBuilder()
-
-            try {
-                inferenceEngine.sendUserPrompt(content)
-                    .catch { e ->
-                        // Handle errors during token collection
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.size >= 2) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                // Create metrics with error indication
-                                val errorMetrics = TokenMetrics(
-                                    tokensCount = tokenCount,
-                                    ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                                    tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
-                                )
-
-                                currentMessages[messageIndex] = Message.Assistant.Completed(
-                                    content = "${response}[Error: ${e.message}]",
-                                    timestamp = currentAssistantMessage.timestamp,
-                                    metrics = errorMetrics
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-                    .onCompletion { cause ->
-                        // Handle completion (normal or cancelled)
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.isNotEmpty()) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                // Calculate final metrics
-                                val endTime = System.currentTimeMillis()
-                                val totalTimeMs = endTime - generationStartTime
-
-                                val metrics = TokenMetrics(
-                                    tokensCount = tokenCount,
-                                    ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                                    tpsMs = calculateTPS(tokenCount, totalTimeMs)
-                                )
-
-                                currentMessages[messageIndex] = Message.Assistant.Completed(
-                                    content = response.toString(),
-                                    timestamp = currentAssistantMessage.timestamp,
-                                    metrics = metrics
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-                    .collect { token ->
-                        // Track first token time
-                        if (isFirstToken && token.isNotBlank()) {
-                            firstTokenTime = System.currentTimeMillis()
-                            isFirstToken = false
-                        }
-
-                        // Count tokens - each non-empty emission is at least one token
-                        if (token.isNotBlank()) {
-                            tokenCount++
-                        }
-
-                        response.append(token)
-
-                        // Safely update the assistant message with the generated text
-                        val currentMessages = _messages.value.toMutableList()
-                        if (currentMessages.isNotEmpty()) {
-                            val messageIndex = currentMessages.size - 1
-                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                            if (currentAssistantMessage != null) {
-                                currentMessages[messageIndex] = Message.Assistant.Ongoing(
-                                    content = response.toString(),
-                                    timestamp = currentAssistantMessage.timestamp
-                                )
-                                _messages.value = currentMessages
-                            }
-                        }
-                    }
-            } catch (e: Exception) {
-                // Handle any unexpected exceptions
-                val currentMessages = _messages.value.toMutableList()
-                if (currentMessages.isNotEmpty()) {
-                    val messageIndex = currentMessages.size - 1
-                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant.Ongoing
-                    if (currentAssistantMessage != null) {
-                        // Create metrics with error indication
-                        val errorMetrics = TokenMetrics(
-                            tokensCount = tokenCount,
-                            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
-                            tpsMs = calculateTPS(tokenCount, System.currentTimeMillis() - generationStartTime)
-                        )
-
-                        currentMessages[messageIndex] = Message.Assistant.Completed(
-                            content = "${response}[Error: ${e.message}]",
-                            timestamp = currentAssistantMessage.timestamp,
-                            metrics = errorMetrics
-                        )
-                        _messages.value = currentMessages
-                    }
-                }
-            }
-        }
-    }
-
-    /**
-     * Calculate tokens per second.
-     */
-    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
-        if (tokens <= 0 || timeMs <= 0) return 0f
-        return (tokens.toFloat() * 1000f) / timeMs
-    }
-
-    /**
-     * Unloads the currently loaded model after cleanup chores:
-     * - Cancel any ongoing token collection
-     * - Clear messages
-     */
-    suspend fun unloadModel() {
-        tokenCollectionJob?.cancel()
-        _messages.value = emptyList()
-
-        inferenceEngine.unloadModel()
-    }
-
-    /**
-     * Clean up resources when ViewModel is cleared.
-     */
-    override fun onCleared() {
-        inferenceEngine.destroy()
-        super.onCleared()
-    }
+    suspend fun unloadModel() = inferenceManager.unloadModel()
 
     companion object {
         private val TAG = MainViewModel::class.java.simpleName
@@ -279,44 +28,3 @@
     }
 }
 
-/**
- * Sealed class representing messages in a conversation.
- */
-sealed class Message {
-    abstract val timestamp: Long
-    abstract val content: String
-
-    val formattedTime: String
-        get() = datetimeFormatter.format(Date(timestamp))
-
-    data class User(
-        override val timestamp: Long,
-        override val content: String
-    ) : Message()
-
-    sealed class Assistant : Message() {
-        data class Ongoing(
-            override val timestamp: Long,
-            override val content: String,
-        ) : Assistant()
-
-        data class Completed(
-            override val timestamp: Long,
-            override val content: String,
-            val metrics: TokenMetrics
-        ) : Assistant()
-    }
-
-    companion object {
-        private val datetimeFormatter by lazy { SimpleDateFormat("h:mm a", Locale.getDefault()) }
-    }
-}
-
-data class TokenMetrics(
-    val tokensCount: Int,
-    val ttftMs: Long,
-    val tpsMs: Float,
-) {
-    val text: String
-        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
-}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt
new file mode 100644
index 0000000000..2b84399fde
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt
@@ -0,0 +1,27 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import javax.inject.Inject
+
+@HiltViewModel
+class ModelLoadingViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager
+) : ViewModel() {
+
+    val engineState = inferenceManager.engineState
+    val selectedModel = inferenceManager.currentModel
+
+    /**
+     * Prepares the engine for benchmark mode.
+     */
+    suspend fun prepareForBenchmark() =
+        inferenceManager.loadModelForBenchmark()
+
+    /**
+     * Prepares the engine for conversation mode.
+     */
+    suspend fun prepareForConversation(systemPrompt: String? = null) =
+        inferenceManager.loadModelForConversation(systemPrompt)
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
new file mode 100644
index 0000000000..f39d024b8c
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
@@ -0,0 +1,58 @@
+package com.example.llama.revamp.viewmodel
+
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.data.repository.ModelRepository
+import com.example.llama.revamp.engine.InferenceManager
+import dagger.hilt.android.lifecycle.HiltViewModel
+import kotlinx.coroutines.flow.SharingStarted
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.stateIn
+import kotlinx.coroutines.launch
+import javax.inject.Inject
+
+
+@HiltViewModel
+class ModelSelectionViewModel @Inject constructor(
+    private val inferenceManager: InferenceManager,
+    private val modelRepository: ModelRepository
+) : ViewModel() {
+
+    /**
+     * Available models for selection
+     */
+    val availableModels: StateFlow<List<ModelInfo>> = modelRepository.getModels()
+        .stateIn(
+            scope = viewModelScope,
+            started = SharingStarted.WhileSubscribed(SUBSCRIPTION_TIMEOUT_MS),
+            initialValue = emptyList()
+        )
+
+    /**
+     * Access to currently selected model
+     */
+    val selectedModel = inferenceManager.currentModel
+
+    /**
+     * Select a model and update its last used timestamp
+     */
+    fun selectModel(modelInfo: ModelInfo) {
+        inferenceManager.setCurrentModel(modelInfo)
+
+        viewModelScope.launch {
+            modelRepository.updateModelLastUsed(modelInfo.id)
+        }
+    }
+
+    /**
+     * Unload model when navigating away
+     */
+    suspend fun unloadModel() = inferenceManager.unloadModel()
+
+    companion object {
+        private val TAG = ModelSelectionViewModel::class.java.simpleName
+
+        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
+    }
+}
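With Message now living next to ConversationViewModel, UI code can branch over the sealed hierarchy exhaustively. A small hedged example of a consumer (a hypothetical helper, not part of the PR, and it assumes the relocated engine.TokenMetrics keeps the text property the old MainViewModel.kt version had):

    import com.example.llama.revamp.viewmodel.Message

    // Hypothetical renderer: the `when` is exhaustive over the sealed hierarchy,
    // so adding a new Message subtype becomes a compile error here.
    fun renderMessage(message: Message): String = when (message) {
        is Message.User ->
            "[${message.formattedTime}] You: ${message.content}"
        is Message.Assistant.Ongoing ->
            "Assistant (generating): ${message.content}"
        is Message.Assistant.Completed ->
            "Assistant: ${message.content}\n  ${message.metrics.text}"  // assumes .text survived the move
    }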