From 65d4a57a8b4cc40ec892f79f2412bbfe80e4e79a Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Wed, 16 Apr 2025 22:29:40 -0700
Subject: [PATCH] LLama: refactor loadModel by splitting the system prompt
 setting into a separate method

---
 .../llama/revamp/data/model/ModelInfo.kt      | 43 -------------
 .../llama/revamp/engine/InferenceServices.kt  | 53 +++++++++++----
 .../revamp/engine/StubInferenceEngine.kt      | 38 +++++++++----
 .../revamp/ui/screens/ModelLoadingScreen.kt   |  4 +-
 .../java/android/llama/cpp/InferenceEngine.kt | 13 ++--
 .../java/android/llama/cpp/LLamaAndroid.kt    | 64 +++++++++++--------
 6 files changed, 111 insertions(+), 104 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
index 62e6295ca6..ad0bb31530 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/data/model/ModelInfo.kt
@@ -18,47 +18,4 @@ data class ModelInfo(
 ) {
     val formattedSize: String
         get() = formatSize(sizeInBytes)
-
-    companion object {
-        /**
-         * Creates a list of sample models for development and testing.
-         */
-        fun getSampleModels(): List<ModelInfo> {
-            return listOf(
-                ModelInfo(
-                    id = "mistral-7b",
-                    name = "Mistral 7B",
-                    path = "/storage/models/mistral-7b-q4_0.gguf",
-                    sizeInBytes = 4_000_000_000,
-                    parameters = "7B",
-                    quantization = "Q4_K_M",
-                    type = "Mistral",
-                    contextLength = 8192,
-                    lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago
-                ),
-                ModelInfo(
-                    id = "llama2-13b",
-                    name = "Llama 2 13B",
-                    path = "/storage/models/llama2-13b-q5_k_m.gguf",
-                    sizeInBytes = 8_500_000_000,
-                    parameters = "13B",
-                    quantization = "Q5_K_M",
-                    type = "Llama",
-                    contextLength = 4096,
-                    lastUsed = System.currentTimeMillis() - 259200000 // 3 days ago
-                ),
-                ModelInfo(
-                    id = "phi-2",
-                    name = "Phi-2",
-                    path = "/storage/models/phi-2.gguf",
-                    sizeInBytes = 2_800_000_000,
-                    parameters = "2.7B",
-                    quantization = "Q4_0",
-                    type = "Phi",
-                    contextLength = 2048,
-                    lastUsed = null
-                )
-            )
-        }
-    }
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
index 1d9d356e33..87e72ea1fd 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): Boolean
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
 
     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }
 
 interface BenchmarkService : InferenceService {
@@ -80,6 +80,17 @@ interface ConversationService : InferenceService {
     fun createTokenMetrics(): TokenMetrics
 }
 
+/**
+ * Metrics for model loading and system prompt processing
+ */
+data class ModelLoadingMetrics(
+    val modelLoadingTimeMs: Long,
+    val systemPromptProcessingTimeMs: Long? = null
+) {
+    val totalTimeMs: Long
+        get() = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
+}
+
 /**
  * Represents an update during text generation
  */
@@ -115,9 +126,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
 
-    override fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
 
     override suspend fun unloadModel() = inferenceEngine.unloadModel()
 
@@ -129,29 +138,45 @@ internal class InferenceServiceImpl @Inject internal constructor(
 
     /* ModelLoadingService implementation */
 
-    override suspend fun loadModelForBenchmark(): Boolean {
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
         return _currentModel.value?.let { model ->
             try {
+                val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
-                true
+                val modelLoadEndTs = System.currentTimeMillis()
+                ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }
 
-    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
-                inferenceEngine.loadModel(model.path, systemPrompt)
-                true
+                val modelLoadStartTs = System.currentTimeMillis()
+                inferenceEngine.loadModel(model.path)
+                val modelLoadEndTs = System.currentTimeMillis()
+
+                if (systemPrompt.isNullOrBlank()) {
+                    ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
+                } else {
+                    val prompt: String = systemPrompt
+                    val systemPromptStartTs = System.currentTimeMillis()
+                    inferenceEngine.setSystemPrompt(prompt)
+                    val systemPromptEndTs = System.currentTimeMillis()
+                    ModelLoadingMetrics(
+                        modelLoadingTimeMs = modelLoadEndTs - modelLoadStartTs,
+                        systemPromptProcessingTimeMs = systemPromptEndTs - systemPromptStartTs
+                    )
+                }
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
index 97943f386d..6d2ee7bd01 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
@@ -42,9 +42,9 @@ class StubInferenceEngine : InferenceEngine {
     }
 
     /**
-     * Loads a model from the given path with an optional system prompt.
+     * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
+    override suspend fun loadModel(pathToModel: String) {
         Log.i(TAG, "loadModel! state: ${_state.value}")
 
         try {
@@ -53,16 +53,28 @@ class StubInferenceEngine : InferenceEngine {
             // Simulate model loading
             delay(STUB_MODEL_LOADING_TIME)
-            _state.value = State.ModelLoaded
+            _state.value = State.ModelReady
 
-            if (systemPrompt != null) {
-                _state.value = State.ProcessingSystemPrompt
-                // Simulate processing system prompt
-                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
-            }
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+        }
+    }
 
-            _state.value = State.AwaitingUserPrompt
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) {
+        try {
+            _state.value = State.ProcessingSystemPrompt
+
+            // Simulate processing system prompt
+            delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during model loading")
         }
     }
@@ -94,10 +106,10 @@ class StubInferenceEngine : InferenceEngine {
                 delay(STUB_TOKEN_GENERATION_TIME)
             }
 
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // Handle cancellation gracefully
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during generation")
@@ -144,12 +156,12 @@ class StubInferenceEngine : InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
 
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
index ecaca9703d..df076dcbb7 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/ui/screens/ModelLoadingScreen.kt
@@ -101,7 +101,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Uninitialized &&
             engineState !is State.LibraryLoaded &&
-            engineState !is State.AwaitingUserPrompt
+            engineState !is State.ModelReady
 
     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -431,7 +431,7 @@ fun ModelLoadingScreen(
                         text = when (engineState) {
                             is State.LoadingModel -> "Loading model..."
                             is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                            is State.ModelLoaded -> "Preparing conversation..."
+                            is State.ModelReady -> "Preparing conversation..."
                             else -> "Processing..."
                         },
                         style = MaterialTheme.typography.titleMedium
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
index 043f0a0037..f38cece58f 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
@@ -13,9 +13,14 @@ interface InferenceEngine {
     val state: StateFlow<State>
 
     /**
-     * Load a model from the given path with an optional system prompt.
+     * Load a model from the given path.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+    suspend fun loadModel(pathToModel: String)
+
+    /**
+     * Sends a system prompt to the loaded model
+     */
+    suspend fun setSystemPrompt(systemPrompt: String)
 
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
@@ -45,11 +50,9 @@ interface InferenceEngine {
         object LibraryLoaded : State()
         object LoadingModel : State()
-        object ModelLoaded : State()
+        object ModelReady : State()
         object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
         object ProcessingUserPrompt : State()
         object Generating : State()
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index 69ffc38409..55c5ff187d 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -18,10 +18,6 @@ import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 import java.io.File
 
-@Target(AnnotationTarget.FUNCTION)
-@Retention(AnnotationRetention.SOURCE)
-annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
-
 /**
  * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
  *
@@ -63,6 +59,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state
 
+    private var _readyForSystemPrompt = false
+
     /**
     * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
     */
@@ -85,9 +83,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     }
 
     /**
-     * Load the LLM, then process the plain text system prompt if provided
+     * Load the LLM
     */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
+    override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {
@@ -96,6 +94,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
             }
 
             Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
             _state.value = State.LoadingModel
             load(pathToModel).let { result ->
                 if (result != 0) throw IllegalStateException("Failed to Load model: $result")
@@ -104,23 +103,31 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
             }
             Log.i(TAG, "Model loaded!")
-            _state.value = State.ModelLoaded
+            _readyForSystemPrompt = true
+            _state.value = State.ModelReady
+        }
 
-            systemPrompt?.let { prompt ->
-                Log.i(TAG, "Sending system prompt...")
-                _state.value = State.ProcessingSystemPrompt
-                processSystemPrompt(prompt).let { result ->
-                    if (result != 0) {
-                        val errorMessage = "Failed to process system prompt: $result"
-                        _state.value = State.Error(errorMessage)
-                        throw IllegalStateException(errorMessage)
-                    }
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+
+            Log.i(TAG, "Sending system prompt...")
+            _readyForSystemPrompt = false
+            _state.value = State.ProcessingSystemPrompt
+            processSystemPrompt(prompt).let { result ->
+                if (result != 0) {
+                    val errorMessage = "Failed to process system prompt: $result"
+                    _state.value = State.Error(errorMessage)
+                    throw IllegalStateException(errorMessage)
                 }
-                Log.i(TAG, "System prompt processed! Awaiting user prompt...")
-            } ?: run {
-                Log.w(TAG, "No system prompt to process.")
             }
-            _state.value = State.AwaitingUserPrompt
+            Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+            _state.value = State.ModelReady
         }
 
     /**
@@ -131,12 +138,13 @@ class LLamaAndroid private constructor() : InferenceEngine {
         predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
-        check(_state.value is State.AwaitingUserPrompt) {
+        check(_state.value is State.ModelReady) {
             "User prompt discarded due to: ${_state.value}"
         }
 
         try {
             Log.i(TAG, "Sending user prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
@@ -153,10 +161,10 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 } ?: break
             }
             Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             Log.i(TAG, "Generation cancelled by user.")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             Log.e(TAG, "Error during generation!", e)
@@ -170,13 +178,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
-            check(_state.value is State.AwaitingUserPrompt) {
+            check(_state.value is State.ModelReady) {
                 "Benchmark request discarded due to: $state"
             }
             Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _readyForSystemPrompt = false // Just to be safe
             _state.value = State.Benchmarking
             benchModel(pp, tg, pl, nr).also {
-                _state.value = State.AwaitingUserPrompt
+                _state.value = State.ModelReady
             }
         }
 
@@ -186,8 +195,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
-                is State.AwaitingUserPrompt, is State.Error -> {
+                is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
+                    _readyForSystemPrompt = false
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
@@ -200,8 +210,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * Cancel all ongoing coroutines and free GGML backends
     */
-    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
     override fun destroy() {
+        _readyForSystemPrompt = false
        llamaScope.cancel()
         when(_state.value) {
             is State.Uninitialized -> {}
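
Usage note: this patch replaces the old single-call contract, loadModel(path,
systemPrompt), with a two-step sequence: loadModel(path) followed by an
optional setSystemPrompt(prompt). In LLamaAndroid, the _readyForSystemPrompt
flag additionally requires that setSystemPrompt() be the very next engine call
after a successful load. A minimal caller-side sketch of the new call order
(illustrative only; prepareEngine is a hypothetical helper, not part of this
patch):

    import android.llama.cpp.InferenceEngine

    // Hypothetical helper showing the intended call order after this refactor.
    suspend fun prepareEngine(
        engine: InferenceEngine,
        modelPath: String,
        systemPrompt: String?,
    ) {
        // LoadingModel -> ModelReady; LLamaAndroid throws if the engine is not
        // currently in State.LibraryLoaded.
        engine.loadModel(modelPath)

        // Optional; must happen immediately after loadModel(), since LLamaAndroid
        // clears its ready-for-system-prompt flag on any other call and would
        // then fail this check with an IllegalStateException.
        if (!systemPrompt.isNullOrBlank()) {
            engine.setSystemPrompt(systemPrompt) // ProcessingSystemPrompt -> ModelReady
        }

        // The engine is back in State.ModelReady and can accept a user prompt
        // or a bench() request.
    }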