From c08d02d2339881fd34c30590ade7b2bf70c3548b Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Fri, 18 Apr 2025 12:10:09 -0700
Subject: [PATCH] LLama: add UnloadingModel state to engine State; add missing
 state checks in stub engine; fix instrumentation engine's error messages

---
 .../revamp/engine/StubInferenceEngine.kt      | 250 +++++++++++-------
 .../java/android/llama/cpp/InferenceEngine.kt |   1 +
 .../java/android/llama/cpp/LLamaAndroid.kt    |  18 +-
 3 files changed, 162 insertions(+), 107 deletions(-)

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
index 6d2ee7bd01..6517b3f022 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
@@ -3,13 +3,19 @@ package com.example.llama.revamp.engine
 import android.llama.cpp.InferenceEngine
 import android.llama.cpp.InferenceEngine.State
 import android.util.Log
+import com.example.llama.revamp.APP_NAME
 import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.SupervisorJob
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableStateFlow
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
 import org.jetbrains.annotations.TestOnly
 import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
@@ -24,161 +30,201 @@ class StubInferenceEngine : InferenceEngine {
     companion object {
         private val TAG = StubInferenceEngine::class.java.simpleName
 
-        private const val STUB_MODEL_LOADING_TIME = 2000L
-        private const val STUB_BENCHMARKING_TIME = 4000L
-        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 3000L
-        private const val STUB_USER_PROMPT_PROCESSING_TIME = 1500L
+        private const val STUB_LIBRARY_LOADING_TIME = 2_000L
+        private const val STUB_MODEL_LOADING_TIME = 3_000L
+        private const val STUB_MODEL_UNLOADING_TIME = 2_000L
+        private const val STUB_BENCHMARKING_TIME = 8_000L
+        private const val STUB_SYSTEM_PROMPT_PROCESSING_TIME = 4_000L
+        private const val STUB_USER_PROMPT_PROCESSING_TIME = 2_000L
 
         private const val STUB_TOKEN_GENERATION_TIME = 200L
     }
 
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state
 
-    init {
-        Log.i(TAG, "Initiated!")
+    private var _readyForSystemPrompt = false
 
-        // Simulate library loading
-        _state.value = State.LibraryLoaded
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+    init {
+        llamaScope.launch {
+            Log.i(TAG, "Initiated!")
+
+            // Simulate library loading
+            delay(STUB_LIBRARY_LOADING_TIME)
+
+            _state.value = State.LibraryLoaded
+        }
     }
 
     /**
      * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String) {
-        Log.i(TAG, "loadModel! state: ${_state.value}")
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            Log.i(TAG, "loadModel! state: ${_state.value.javaClass.simpleName}")
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model at ${_state.value.javaClass.simpleName}"
+            }
 
-        try {
-            _state.value = State.LoadingModel
+            try {
+                _readyForSystemPrompt = false
+                _state.value = State.LoadingModel
 
-            // Simulate model loading
-            delay(STUB_MODEL_LOADING_TIME)
+                // Simulate model loading
+                delay(STUB_MODEL_LOADING_TIME)
 
-            _state.value = State.ModelReady
+                _readyForSystemPrompt = true
+                _state.value = State.ModelReady
 
-
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during model loading")
+            }
         }
-    }
 
     /**
      * Process the plain text system prompt
      */
-    override suspend fun setSystemPrompt(prompt: String) {
-        try {
-            _state.value = State.ProcessingSystemPrompt
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot set system prompt at ${_state.value.javaClass.simpleName}"
+            }
+            check(_readyForSystemPrompt) {
+                "System prompt must be set ** RIGHT AFTER ** model loaded!"
+            }
 
-            // Simulate processing system prompt
-            delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+            try {
+                _state.value = State.ProcessingSystemPrompt
 
-            _state.value = State.ModelReady
-        } catch (e: CancellationException) {
-            // If coroutine is cancelled, propagate cancellation
-            throw e
-        } catch (e: Exception) {
-            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+                // Simulate processing system prompt
+                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+
+                _state.value = State.ModelReady
+            } catch (e: CancellationException) {
+                // If coroutine is cancelled, propagate cancellation
+                throw e
+            } catch (e: Exception) {
+                _state.value = State.Error(e.message ?: "Unknown error during system prompt processing")
+            }
         }
-    }
 
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
-    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> {
-        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
+    override fun sendUserPrompt(message: String, predictLength: Int): Flow<String> = flow {
+        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+        check(_state.value is State.ModelReady) {
+            "Cannot send user prompt at ${_state.value.javaClass.simpleName}"
+        }
 
-        _state.value = State.ProcessingUserPrompt
+        try {
+            Log.i(TAG, "sendUserPrompt! \n$message")
+            _state.value = State.ProcessingUserPrompt
 
-        // This would be replaced with actual token generation logic
-        return flow {
-            try {
-                // Simulate longer processing time (1.5 seconds)
-                delay(STUB_USER_PROMPT_PROCESSING_TIME)
+            // Simulate longer processing time
+            delay(STUB_USER_PROMPT_PROCESSING_TIME)
 
-                _state.value = State.Generating
+            _state.value = State.Generating
 
-                // Simulate token generation
-                val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
-                response.split(" ").forEach {
-                    emit("$it ")
-                    delay(STUB_TOKEN_GENERATION_TIME)
-                }
-
-                _state.value = State.ModelReady
-            } catch (e: CancellationException) {
-                // Handle cancellation gracefully
-                _state.value = State.ModelReady
-                throw e
-            } catch (e: Exception) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
-                throw e
-            }
-        }.catch { e ->
-            // If it's not a cancellation, update state to error
-            if (e !is CancellationException) {
-                _state.value = State.Error(e.message ?: "Unknown error during generation")
+            // Simulate token generation
+            val response = "This is a simulated response from the LLM model. The actual implementation would generate tokens one by one based on the input: $message"
+            response.split(" ").forEach {
+                emit("$it ")
+                delay(STUB_TOKEN_GENERATION_TIME)
             }
+
+            _state.value = State.ModelReady
+        } catch (e: CancellationException) {
+            // Handle cancellation gracefully
+            _state.value = State.ModelReady
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
             throw e
         }
+    }.catch { e ->
+        // If it's not a cancellation, update state to error
+        if (e !is CancellationException) {
+            _state.value = State.Error(e.message ?: "Unknown error during generation")
+        }
+        throw e
     }
 
     /**
      * Runs a benchmark with the specified parameters.
      */
-    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String {
-        Log.i(TAG, "bench! state: ${_state.value}")
+    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        withContext(llamaDispatcher) {
+            check(_state.value is State.ModelReady) {
+                "Cannot run benchmark at ${_state.value.javaClass.simpleName}"
+            }
 
-        _state.value = State.Benchmarking
+            try {
+                Log.i(TAG, "bench! state: ${_state.value}")
state: ${_state.value}") + _state.value = State.Benchmarking - try { - // Simulate benchmark running - delay(STUB_BENCHMARKING_TIME) + // Simulate benchmark running + delay(STUB_BENCHMARKING_TIME) - // Generate fake benchmark results - val modelDesc = "Kleidi Llama" - val model_size = "7" - val model_n_params = "7" - val backend = "CPU" + // Generate fake benchmark results + val modelDesc = APP_NAME + val model_size = "7" + val model_n_params = "7" + val backend = "CPU" - // Random values for benchmarks - val pp_avg = (51.4 + Math.random() * 5.14).toFloat() - val pp_std = (5.14 + Math.random() * 0.514).toFloat() - val tg_avg = (11.4 + Math.random() * 1.14).toFloat() - val tg_std = (1.14 + Math.random() * 0.114).toFloat() + // Random values for benchmarks + val pp_avg = (51.4 + Math.random() * 5.14).toFloat() + val pp_std = (5.14 + Math.random() * 0.514).toFloat() + val tg_avg = (11.4 + Math.random() * 1.14).toFloat() + val tg_std = (1.14 + Math.random() * 0.114).toFloat() - val result = StringBuilder() - result.append("| model | size | params | backend | test | t/s |\n") - result.append("| --- | --- | --- | --- | --- | --- |\n") - result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ") - result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n") - result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ") - result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n") + val result = StringBuilder() + result.append("| model | size | params | backend | test | t/s |\n") + result.append("| --- | --- | --- | --- | --- | --- |\n") + result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ") + result.append("$backend | pp $pp | $pp_avg ± $pp_std |\n") + result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ") + result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n") - _state.value = State.ModelReady + _state.value = State.ModelReady - return result.toString() - } catch (e: CancellationException) { - // If coroutine is cancelled, propagate cancellation - _state.value = State.ModelReady - throw e - } catch (e: Exception) { - _state.value = State.Error(e.message ?: "Unknown error during benchmarking") - return "Error: ${e.message}" + result.toString() + } catch (e: CancellationException) { + // If coroutine is cancelled, propagate cancellation + Log.w(TAG, "Unexpected user cancellation while benchmarking!") + _state.value = State.ModelReady + throw e + } catch (e: Exception) { + _state.value = State.Error(e.message ?: "Unknown error during benchmarking") + "Error: ${e.message}" + } } - } /** * Unloads the currently loaded model. */ - override suspend fun unloadModel() { - Log.i(TAG, "unloadModel! state: ${_state.value}") + override suspend fun unloadModel() = + withContext(llamaDispatcher) { + when(val state = _state.value) { + is State.ModelReady, is State.Error -> { + Log.i(TAG, "unloadModel! state: ${_state.value.javaClass.simpleName}") + _state.value = State.UnloadingModel - // Simulate model unloading time - delay(2000) - _state.value = State.LibraryLoaded - } + // Simulate model unloading time + delay(STUB_MODEL_UNLOADING_TIME) + + _state.value = State.LibraryLoaded + } + else -> throw IllegalStateException( + "Cannot load model at ${_state.value.javaClass.simpleName}" + ) + } + } /** * Cleans up resources when the engine is no longer needed. 
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
index f38cece58f..b1f1733ea4 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngine.kt
@@ -50,6 +50,7 @@ interface InferenceEngine {
 
         object LibraryLoaded : State()
         object LoadingModel : State()
+        object UnloadingModel : State()
         object ModelReady : State()
         object ProcessingSystemPrompt : State()
 
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index 55c5ff187d..cbf8c4823a 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -87,7 +87,9 @@
      */
     override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
-            check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
+            check(_state.value is State.LibraryLoaded) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }
             File(pathToModel).let {
                 require(it.exists()) { "Model file not found: $pathToModel" }
                 require(it.isFile) { "Model file is not a file: $pathToModel" }
@@ -114,7 +116,9 @@
         withContext(llamaDispatcher) {
             require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
             check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
-            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+            check(_state.value is State.ModelReady) {
+                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+            }
 
             Log.i(TAG, "Sending system prompt...")
             _readyForSystemPrompt = false
@@ -139,13 +143,14 @@
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
         check(_state.value is State.ModelReady) {
-            "User prompt discarded due to: ${_state.value}"
+            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
         }
 
         try {
             Log.i(TAG, "Sending user prompt...")
             _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
+
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
                     Log.e(TAG, "Failed to process user prompt: $result")
@@ -194,16 +199,19 @@
      */
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
-            when(_state.value) {
+            when(val state = _state.value) {
                 is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
                     _readyForSystemPrompt = false
+                    _state.value = State.UnloadingModel
+
                     unload()
+
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
                     Unit
                 }
-                else -> throw IllegalStateException("Cannot unload model in ${_state.value}")
+                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
             }
         }
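Reviewer note (not part of the patch): a minimal caller-side sketch of the state contract this change establishes, assuming a hypothetical `engine: InferenceEngine` instance and a caller-owned CoroutineScope. The member and state names (`state`, `loadModel`, `sendUserPrompt`, `unloadModel`, `UnloadingModel`, `LibraryLoaded`, ...) come from the diff above; the helper name, prompt, and scope wiring are illustrative only.

import android.llama.cpp.InferenceEngine
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch

// Hypothetical helper: drives one load -> prompt -> unload cycle while logging
// every state transition, including the UnloadingModel state added here.
fun runEngineOnce(engine: InferenceEngine, pathToModel: String, scope: CoroutineScope) {
    // Observe the engine's StateFlow; with this patch, unloadModel() passes
    // through UnloadingModel before settling back on LibraryLoaded.
    scope.launch {
        engine.state.collect { state ->
            println("engine state -> ${state.javaClass.simpleName}")
        }
    }

    scope.launch {
        engine.loadModel(pathToModel)            // LoadingModel -> ModelReady
        engine.sendUserPrompt("Hello!", 64)      // ProcessingUserPrompt -> Generating -> ModelReady
            .collect { token -> print(token) }
        engine.unloadModel()                     // UnloadingModel -> LibraryLoaded
    }
}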