From 75c986afc5b4864b1cb15b0f057ee61f2bcd8bea Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Sat, 12 Apr 2025 12:36:11 -0700
Subject: [PATCH] bugfix: properly handle the user quitting the conversation
 screen while tokens are being generated

---
 .../llama/revamp/engine/InferenceEngine.kt    | 52 +++++++---
 .../llama/revamp/viewmodel/MainViewModel.kt   | 98 +++++++++++++++----
 2 files changed, 115 insertions(+), 35 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
index a7e5a9bcf9..619cb43c00 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
@@ -1,9 +1,11 @@
 package com.example.llama.revamp.engine
 
+import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
 
 /**
@@ -39,7 +41,8 @@ class InferenceEngine {
 
     // Keep track of current benchmark results
     private var _benchmarkResults: String? = null
-    val benchmarkResults: StateFlow<String?> = MutableStateFlow(_benchmarkResults)
+    private val _benchmarkResultsFlow = MutableStateFlow<String?>(null)
+    val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
 
     init {
         // Simulate library loading
@@ -66,6 +69,9 @@ class InferenceEngine {
             }
 
             _state.value = State.AwaitingUserPrompt
+        } catch (e: CancellationException) {
+            // If the coroutine is cancelled, propagate the cancellation
+            throw e
         } catch (e: Exception) {
            _state.value = State.Error(e.message ?: "Unknown error during model loading")
         }
@@ -79,21 +85,35 @@
 
         // This would be replaced with actual token generation logic
         return flow {
-            delay(500) // Simulate processing time
+            try {
+                delay(500) // Simulate processing time
 
-            _state.value = State.Generating
+                _state.value = State.Generating
 
-            // Simulate token generation
-            val response =
-                "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-            val words = response.split(" ")
+                // Simulate token generation
+                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+                val words = response.split(" ")
 
-            for (word in words) {
-                emit(word + " ")
-                delay(50) // Simulate token generation delay
+                for (word in words) {
+                    emit(word + " ")
+                    delay(50) // Simulate token generation delay
+                }
+
+                _state.value = State.AwaitingUserPrompt
+            } catch (e: CancellationException) {
+                // Handle cancellation gracefully
+                _state.value = State.AwaitingUserPrompt
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+                throw e
             }
-
-            _state.value = State.AwaitingUserPrompt
+        }.catch { e ->
+            // If it's not a cancellation, update the state to an error
+            if (e !is CancellationException) {
+                _state.value = State.Error(e.message ?: "Unknown error during generation")
+            }
+            throw e
         }
     }
 
@@ -128,11 +148,15 @@ class InferenceEngine {
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
             _benchmarkResults = result.toString()
-            (benchmarkResults as MutableStateFlow<String?>).value = _benchmarkResults
+            _benchmarkResultsFlow.value = _benchmarkResults
 
             _state.value = State.AwaitingUserPrompt
 
             return _benchmarkResults ?: ""
+        } catch (e: CancellationException) {
+            // If the coroutine is cancelled, propagate the cancellation
+            _state.value = State.AwaitingUserPrompt
+            throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
             return "Error: ${e.message}"
@@ -147,7 +171,7 @@ class InferenceEngine {
         delay(300)
         _state.value = State.LibraryLoaded
         _benchmarkResults = null
-        (benchmarkResults as MutableStateFlow<String?>).value = null
+        _benchmarkResultsFlow.value = null
     }
 
     /**
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
index 45170819fc..19e5313eb8 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
@@ -5,9 +5,12 @@ import androidx.lifecycle.ViewModelProvider
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.engine.InferenceEngine
+import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.catch
+import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.launch
 import java.text.SimpleDateFormat
 import java.util.Date
@@ -43,6 +46,9 @@ class MainViewModel(
     private val _systemPrompt = MutableStateFlow<String?>(null)
     val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
 
+    // Tracks the active token-collection job, if any
+    private var tokenCollectionJob: Job? = null
+
     /**
      * Selects a model for use.
     */
@@ -96,6 +102,9 @@
     fun sendMessage(content: String) {
         if (content.isBlank()) return
 
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
         // Add user message
         val userMessage = Message.User(
             content = content,
@@ -112,31 +121,72 @@
         _messages.value = _messages.value + assistantMessage
 
         // Get response from engine
-        val messageIndex = _messages.value.size - 1
-
-        viewModelScope.launch {
+        tokenCollectionJob = viewModelScope.launch {
             val response = StringBuilder()
-            inferenceEngine.sendUserPrompt(content).collect { token ->
-                response.append(token)
+            try {
+                inferenceEngine.sendUserPrompt(content)
+                    .catch { e ->
+                        // Handle errors during token collection
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.size >= 2) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages[messageIndex] as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = "${response}[Error: ${e.message}]",
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .onCompletion { cause ->
+                        // Handle completion (normal or cancelled)
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    isComplete = true
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+                    .collect { token ->
+                        response.append(token)
 
-                // Update the assistant message with the generated text
+                        // Safely update the assistant message with the generated text
+                        val currentMessages = _messages.value.toMutableList()
+                        if (currentMessages.isNotEmpty()) {
+                            val messageIndex = currentMessages.size - 1
+                            val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                            if (currentAssistantMessage != null) {
+                                currentMessages[messageIndex] = currentAssistantMessage.copy(
+                                    content = response.toString(),
+                                    isComplete = false
+                                )
+                                _messages.value = currentMessages
+                            }
+                        }
+                    }
+            } catch (e: Exception) {
+                // Handle any unexpected exceptions
                 val currentMessages = _messages.value.toMutableList()
-                val currentAssistantMessage = currentMessages[messageIndex] as Message.Assistant
-                currentMessages[messageIndex] = currentAssistantMessage.copy(
-                    content = response.toString(),
-                    isComplete = false
-                )
-                _messages.value = currentMessages
+                if (currentMessages.isNotEmpty()) {
+                    val messageIndex = currentMessages.size - 1
+                    val currentAssistantMessage = currentMessages.getOrNull(messageIndex) as? Message.Assistant
+                    if (currentAssistantMessage != null) {
+                        currentMessages[messageIndex] = currentAssistantMessage.copy(
+                            content = "${response}[Error: ${e.message}]",
+                            isComplete = true
+                        )
+                        _messages.value = currentMessages
+                    }
+                }
             }
-
-            // Mark message as complete when generation finishes
-            val finalMessages = _messages.value.toMutableList()
-            val finalAssistantMessage = finalMessages[messageIndex] as Message.Assistant
-            finalMessages[messageIndex] = finalAssistantMessage.copy(
-                isComplete = true
-            )
-            _messages.value = finalMessages
         }
     }
 
@@ -144,8 +194,14 @@ class MainViewModel(
      * Unloads the currently loaded model.
     */
    suspend fun unloadModel() {
-        inferenceEngine.unloadModel()
+        // Cancel any ongoing token collection
+        tokenCollectionJob?.cancel()
+
+        // Clear messages
         _messages.value = emptyList()
+
+        // Unload model
+        inferenceEngine.unloadModel()
     }
 
     /**
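
A note on the cancellation handling above: inside a coroutine, a generic catch (e: Exception) also traps CancellationException, so if tokenCollectionJob?.cancel() fires while collect is suspended, the cancelled turn can land in the error branch of sendMessage and be suffixed with "[Error: ...]" instead of simply finishing via onCompletion. Below is a minimal, self-contained Kotlin sketch of the cancellation-aware collection pattern; fakeTokens() is a hypothetical stand-in for InferenceEngine.sendUserPrompt(), and every name in it is illustrative rather than project code.

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

// Hypothetical stand-in for InferenceEngine.sendUserPrompt():
// emits one word roughly every 50 ms, like the simulated engine above.
fun fakeTokens(): Flow<String> = flow {
    "simulated token stream for testing cancellation".split(" ").forEach { word ->
        emit("$word ")
        delay(50)
    }
}

fun main() = runBlocking {
    val response = StringBuilder()
    val job = launch {
        try {
            fakeTokens()
                .onCompletion { cause ->
                    // cause is null on normal completion and a CancellationException
                    // when the collecting job is cancelled; finalize UI state here.
                    println("done, cause=$cause, partial=\"$response\"")
                }
                .collect { token -> response.append(token) }
        } catch (e: CancellationException) {
            throw e // rethrow before any generic catch so cancellation propagates
        } catch (e: Exception) {
            println("generation failed: ${e.message}") // genuine errors only
        }
    }
    delay(120) // simulate the user quitting the conversation screen mid-generation
    job.cancelAndJoin()
}

Run as-is, the sketch prints the onCompletion line with a non-null cancellation cause plus whatever partial response was collected; removing the CancellationException rethrow would route the cancelled turn through the generic error branch, which is the failure mode described above.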