bugfix: wait for model to load before navigating to benchmark screen; use NavigationActions instead of raw navController

This commit is contained in:
Han Yin 2025-04-13 22:14:09 -07:00
parent ea11ee3c94
commit 511df35704
9 changed files with 38 additions and 63 deletions

View File

@ -97,18 +97,16 @@ fun AppContent(
// Model unloading confirmation // Model unloading confirmation
var showUnloadDialog by remember { mutableStateOf(false) } var showUnloadDialog by remember { mutableStateOf(false) }
// Helper function to handle back press with model unloading check
val handleBackWithModelCheck = { val handleBackWithModelCheck = {
if (isModelLoading) { if (isModelLoading) {
// If model is still loading, ignore the request // If model is still loading, ignore the request
true // Mark as handled true // Mark as handled
} else if (isModelLoaded) { } else if (isModelLoaded) {
showUnloadDialog = true showUnloadDialog = true
pendingNavigation = { navController.popBackStack() } pendingNavigation = { navigationActions.navigateUp() }
true // Mark as handled true // Mark as handled
} else { } else {
navController.popBackStack() navigationActions.navigateUp()
true // Mark as handled true // Mark as handled
} }
} }
@ -138,14 +136,9 @@ fun AppContent(
val drawerGesturesEnabled by remember(currentRoute, drawerState.currentValue) { val drawerGesturesEnabled by remember(currentRoute, drawerState.currentValue) {
derivedStateOf { derivedStateOf {
// Always allow gesture dismissal when drawer is open // Always allow gesture dismissal when drawer is open
if (drawerState.currentValue == DrawerValue.Open) { if (drawerState.currentValue == DrawerValue.Open) true
true // Only enable drawer opening by gesture on these screens
} else { else currentRoute == AppDestinations.MODEL_SELECTION_ROUTE
// Only enable drawer opening by gesture on these screens
currentRoute == AppDestinations.MODEL_SELECTION_ROUTE ||
currentRoute == AppDestinations.SETTINGS_GENERAL_ROUTE ||
currentRoute == AppDestinations.MODELS_MANAGEMENT_ROUTE
}
} }
} }
@ -191,8 +184,6 @@ fun AppContent(
navigationActions.navigateToModelsManagement() navigationActions.navigateToModelsManagement()
}, },
onMenuClicked = openDrawer, onMenuClicked = openDrawer,
drawerState = drawerState,
navigationActions = navigationActions
) )
} }
@ -201,8 +192,21 @@ fun AppContent(
ModelLoadingScreen( ModelLoadingScreen(
engineState = engineState, engineState = engineState,
onBenchmarkSelected = { onBenchmarkSelected = {
mainVewModel.prepareForBenchmark() // Store a reference to the loading job
navigationActions.navigateToBenchmark() val loadingJob = coroutineScope.launch {
mainVewModel.prepareForBenchmark()
// Check if the job wasn't cancelled before navigating
if (isActive) {
navigationActions.navigateToBenchmark()
}
}
// Update the pendingNavigation handler to cancel any ongoing loading
pendingNavigation = {
loadingJob.cancel()
navigationActions.navigateUp()
}
}, },
onConversationSelected = { systemPrompt -> onConversationSelected = { systemPrompt ->
// Store a reference to the loading job // Store a reference to the loading job
@ -216,15 +220,13 @@ fun AppContent(
// Update the pendingNavigation handler to cancel any ongoing loading // Update the pendingNavigation handler to cancel any ongoing loading
pendingNavigation = { pendingNavigation = {
loadingJob.cancel() loadingJob.cancel()
navController.popBackStack() navigationActions.navigateUp()
} }
}, },
onBackPressed = { onBackPressed = {
// Need to unload model before going back // Need to unload model before going back
handleBackWithModelCheck() handleBackWithModelCheck()
}, },
drawerState = drawerState,
navigationActions = navigationActions
) )
} }
@ -252,8 +254,6 @@ fun AppContent(
onSharePressed = { onSharePressed = {
// Stub for sharing functionality // Stub for sharing functionality
}, },
drawerState = drawerState,
navigationActions = navigationActions,
viewModel = mainVewModel viewModel = mainVewModel
) )
} }
@ -261,9 +261,7 @@ fun AppContent(
// Settings General Screen // Settings General Screen
composable(AppDestinations.SETTINGS_GENERAL_ROUTE) { composable(AppDestinations.SETTINGS_GENERAL_ROUTE) {
SettingsGeneralScreen( SettingsGeneralScreen(
onBackPressed = { navController.popBackStack() }, onBackPressed = { navigationActions.navigateUp() },
drawerState = drawerState,
navigationActions = navigationActions,
onMenuClicked = openDrawer onMenuClicked = openDrawer
) )
} }
@ -271,7 +269,7 @@ fun AppContent(
// Models Management Screen // Models Management Screen
composable(AppDestinations.MODELS_MANAGEMENT_ROUTE) { composable(AppDestinations.MODELS_MANAGEMENT_ROUTE) {
ModelsManagementScreen( ModelsManagementScreen(
onBackPressed = { navController.popBackStack() }, onBackPressed = { navigationActions.navigateUp() },
) )
} }
} }

View File

@ -11,7 +11,6 @@ import androidx.compose.runtime.remember
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
// DefaultAppScaffold.kt
@Composable @Composable
fun DefaultAppScaffold( fun DefaultAppScaffold(
title: String, title: String,
@ -35,7 +34,6 @@ fun DefaultAppScaffold(
) )
} }
// PerformanceAppScaffold.kt
@Composable @Composable
fun PerformanceAppScaffold( fun PerformanceAppScaffold(
performanceViewModel: PerformanceViewModel = hiltViewModel(), performanceViewModel: PerformanceViewModel = hiltViewModel(),
@ -68,7 +66,6 @@ fun PerformanceAppScaffold(
) )
} }
// StorageAppScaffold.kt
@Composable @Composable
fun StorageAppScaffold( fun StorageAppScaffold(
title: String, title: String,

View File

@ -26,7 +26,6 @@ import com.example.llama.revamp.monitoring.MemoryMetrics
import com.example.llama.revamp.monitoring.TemperatureMetrics import com.example.llama.revamp.monitoring.TemperatureMetrics
import com.example.llama.revamp.monitoring.TemperatureWarningLevel import com.example.llama.revamp.monitoring.TemperatureWarningLevel
// TopAppBars.kt
@OptIn(ExperimentalMaterial3Api::class) @OptIn(ExperimentalMaterial3Api::class)
@Composable @Composable
fun DefaultTopBar( fun DefaultTopBar(

View File

@ -13,10 +13,10 @@ import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DrawerState
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text import androidx.compose.material3.Text
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.collectAsState import androidx.compose.runtime.collectAsState
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
@ -24,7 +24,6 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.ui.theme.MonospacedTextStyle import com.example.llama.revamp.ui.theme.MonospacedTextStyle
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.MainViewModel
@ -34,14 +33,16 @@ fun BenchmarkScreen(
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
onRerunPressed: () -> Unit, onRerunPressed: () -> Unit,
onSharePressed: () -> Unit, onSharePressed: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions,
viewModel: MainViewModel = hiltViewModel() viewModel: MainViewModel = hiltViewModel()
) { ) {
val engineState by viewModel.engineState.collectAsState() val engineState by viewModel.engineState.collectAsState()
val benchmarkResults by viewModel.benchmarkResults.collectAsState() val benchmarkResults by viewModel.benchmarkResults.collectAsState()
val selectedModel by viewModel.selectedModel.collectAsState() val selectedModel by viewModel.selectedModel.collectAsState()
LaunchedEffect(selectedModel) {
viewModel.runBenchmark()
}
PerformanceAppScaffold( PerformanceAppScaffold(
title = "Chat", title = "Chat",
onNavigateBack = onBackPressed, onNavigateBack = onBackPressed,

View File

@ -55,11 +55,11 @@ import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip import androidx.compose.ui.draw.clip
import androidx.compose.ui.graphics.Color import androidx.compose.ui.graphics.Color
import androidx.compose.ui.graphics.StrokeCap import androidx.compose.ui.graphics.StrokeCap
import androidx.compose.ui.platform.LocalLifecycleOwner
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import androidx.lifecycle.Lifecycle import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver import androidx.lifecycle.LifecycleEventObserver
import androidx.lifecycle.compose.LocalLifecycleOwner
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.viewmodel.MainViewModel import com.example.llama.revamp.viewmodel.MainViewModel
@ -159,11 +159,12 @@ fun ConversationScreen(
} }
@Composable @Composable
fun AnimatedSystemPrompt(modelName: String?, systemPrompt: String?) { fun AnimatedSystemPrompt(
modelName: String?, // TODO-han.yin: add model name into this card, on top of system prompt!
systemPrompt: String?
) {
var expanded by remember { mutableStateOf(false) } var expanded by remember { mutableStateOf(false) }
// TODO-han.yin: add model name into this card, on top of system prompt!
if (!systemPrompt.isNullOrBlank()) { if (!systemPrompt.isNullOrBlank()) {
Card( Card(
modifier = Modifier modifier = Modifier

View File

@ -23,7 +23,6 @@ import androidx.compose.material.icons.filled.Check
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.CircularProgressIndicator
import androidx.compose.material3.DrawerState
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
@ -50,7 +49,6 @@ import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.data.model.SystemPrompt import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.engine.InferenceEngine import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import com.example.llama.revamp.viewmodel.SystemPromptViewModel import com.example.llama.revamp.viewmodel.SystemPromptViewModel
@ -66,8 +64,6 @@ fun ModelLoadingScreen(
onBenchmarkSelected: () -> Unit, onBenchmarkSelected: () -> Unit,
onConversationSelected: (String?) -> Unit, onConversationSelected: (String?) -> Unit,
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions
) { ) {
val presetPrompts by viewModel.presetPrompts.collectAsState() val presetPrompts by viewModel.presetPrompts.collectAsState()
val recentPrompts by viewModel.recentPrompts.collectAsState() val recentPrompts by viewModel.recentPrompts.collectAsState()

View File

@ -14,7 +14,6 @@ import androidx.compose.material.icons.Icons
import androidx.compose.material.icons.filled.PlayArrow import androidx.compose.material.icons.filled.PlayArrow
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.CardDefaults import androidx.compose.material3.CardDefaults
import androidx.compose.material3.DrawerState
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.Icon import androidx.compose.material3.Icon
import androidx.compose.material3.IconButton import androidx.compose.material3.IconButton
@ -27,7 +26,6 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import com.example.llama.revamp.data.model.ModelInfo import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.PerformanceAppScaffold import com.example.llama.revamp.ui.components.PerformanceAppScaffold
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.util.Date import java.util.Date
@ -39,8 +37,6 @@ fun ModelSelectionScreen(
onModelSelected: (ModelInfo) -> Unit, onModelSelected: (ModelInfo) -> Unit,
onManageModelsClicked: () -> Unit, onManageModelsClicked: () -> Unit,
onMenuClicked: () -> Unit, onMenuClicked: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions
) { ) {
// For demo purposes, we'll use sample models // For demo purposes, we'll use sample models
val models = remember { ModelInfo.getSampleModels() } val models = remember { ModelInfo.getSampleModels() }

View File

@ -10,7 +10,6 @@ import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.rememberScrollState
import androidx.compose.foundation.verticalScroll import androidx.compose.foundation.verticalScroll
import androidx.compose.material3.Card import androidx.compose.material3.Card
import androidx.compose.material3.DrawerState
import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.HorizontalDivider
import androidx.compose.material3.MaterialTheme import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Switch import androidx.compose.material3.Switch
@ -22,7 +21,6 @@ import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp import androidx.compose.ui.unit.dp
import androidx.hilt.navigation.compose.hiltViewModel import androidx.hilt.navigation.compose.hiltViewModel
import com.example.llama.revamp.navigation.NavigationActions
import com.example.llama.revamp.ui.components.DefaultAppScaffold import com.example.llama.revamp.ui.components.DefaultAppScaffold
import com.example.llama.revamp.viewmodel.PerformanceViewModel import com.example.llama.revamp.viewmodel.PerformanceViewModel
@ -33,8 +31,6 @@ import com.example.llama.revamp.viewmodel.PerformanceViewModel
fun SettingsGeneralScreen( fun SettingsGeneralScreen(
performanceViewModel: PerformanceViewModel = hiltViewModel(), performanceViewModel: PerformanceViewModel = hiltViewModel(),
onBackPressed: () -> Unit, onBackPressed: () -> Unit,
drawerState: DrawerState,
navigationActions: NavigationActions,
onMenuClicked: () -> Unit onMenuClicked: () -> Unit
) { ) {
// Collect state from ViewModel // Collect state from ViewModel

View File

@ -57,30 +57,21 @@ class MainViewModel @Inject constructor (
/** /**
* Prepares the engine for benchmark mode. * Prepares the engine for benchmark mode.
*/ */
fun prepareForBenchmark() { suspend fun prepareForBenchmark() {
viewModelScope.launch { _selectedModel.value?.let { model ->
_selectedModel.value?.let { model -> inferenceEngine.loadModel(model.path)
inferenceEngine.loadModel(model.path)
runBenchmark()
}
} }
} }
/** /**
* Runs the benchmark with current parameters. * Runs the benchmark with current parameters.
*/ */
private suspend fun runBenchmark() { suspend fun runBenchmark() = inferenceEngine.bench(512, 128, 1, 3)
inferenceEngine.bench(512, 128, 1, 3)
}
/** /**
* Reruns the benchmark. * Reruns the benchmark.
*/ */
fun rerunBenchmark() { fun rerunBenchmark() = viewModelScope.launch { runBenchmark() }
viewModelScope.launch {
runBenchmark()
}
}
/** /**
* Prepares the engine for conversation mode. * Prepares the engine for conversation mode.