From 32d778bb8e3a3320cfadfe06f42d2a0ba964152e Mon Sep 17 00:00:00 2001
From: Han Yin
Date: Tue, 15 Apr 2025 13:46:52 -0700
Subject: [PATCH] core: extract conversation and benchmark logic into
 InferenceManager; add logs and missing state updates in stub InferenceEngine

---
 .../llama/revamp/engine/InferenceEngine.kt  |  19 +-
 .../llama/revamp/engine/InferenceManager.kt | 169 ++++++++++++++++++
 2 files changed, 186 insertions(+), 2 deletions(-)
 create mode 100644 examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt

diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
index 9c175d56d9..61f089dc66 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceEngine.kt
@@ -1,5 +1,6 @@
 package com.example.llama.revamp.engine
 
+import android.util.Log
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
@@ -13,7 +14,9 @@ import kotlinx.coroutines.flow.flow
  */
 class InferenceEngine {
     companion object {
-        const val DEFAULT_PREDICT_LENGTH = 1024
+        private val TAG = InferenceEngine::class.java.simpleName
+
+        private const val DEFAULT_PREDICT_LENGTH = 1024
     }
 
     sealed class State {
@@ -45,6 +48,8 @@ class InferenceEngine {
     val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
 
     init {
+        Log.i(TAG, "Initiated!")
+
         // Simulate library loading
         _state.value = State.LibraryLoaded
     }
@@ -53,6 +58,8 @@
      * Loads a model from the given path with an optional system prompt.
      */
     suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
+        Log.i(TAG, "loadModel! state: ${_state.value}")
+
         try {
             _state.value = State.LoadingModel
 
@@ -81,6 +88,8 @@
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
      */
     fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
+        Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
+
         _state.value = State.ProcessingUserPrompt
 
         // This would be replaced with actual token generation logic
@@ -123,6 +132,8 @@
      * Runs a benchmark with the specified parameters.
      */
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
+        Log.i(TAG, "bench! state: ${_state.value}")
+
         _state.value = State.Benchmarking
 
         try {
@@ -169,6 +180,8 @@
      * Unloads the currently loaded model.
      */
     suspend fun unloadModel() {
+        Log.i(TAG, "unloadModel! state: ${_state.value}")
+
         // Simulate model unloading time
         delay(2000)
         _state.value = State.LibraryLoaded
@@ -180,6 +193,8 @@
      * Cleans up resources when the engine is no longer needed.
      */
     fun destroy() {
-        // In a real implementation, this would release native resources
+        Log.i(TAG, "destroy! state: ${_state.value}")
+
+        _state.value = State.Uninitialized
     }
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt
new file mode 100644
index 0000000000..e0b36eec5e
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt
@@ -0,0 +1,169 @@
+package com.example.llama.revamp.engine
+
+import android.util.Log
+import com.example.llama.revamp.data.model.ModelInfo
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.flow
+import javax.inject.Inject
+import javax.inject.Singleton
+
+@Singleton
+class InferenceManager @Inject constructor(
+    private val inferenceEngine: InferenceEngine
+) {
+    // Expose engine state
+    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+
+    // Benchmark results
+    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
+
+    // Currently loaded model
+    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
+    val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
+
+    // System prompt
+    private val _systemPrompt = MutableStateFlow<String?>(null)
+    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
+
+    // Token metrics tracking
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
+    /**
+     * Set current model
+     */
+    fun setCurrentModel(model: ModelInfo) {
+        _currentModel.value = model
+    }
+
+    /**
+     * Load a model for benchmark
+     */
+    suspend fun loadModelForBenchmark(): Boolean {
+        return _currentModel.value?.let { model ->
+            try {
+                inferenceEngine.loadModel(model.path)
+                true
+            } catch (e: Exception) {
+                Log.e("InferenceManager", "Error loading model", e)
+                false
+            }
+        } ?: false
+    }
+
+    /**
+     * Load a model for conversation
+     */
+    suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
+        _systemPrompt.value = systemPrompt
+        return _currentModel.value?.let { model ->
+            try {
+                inferenceEngine.loadModel(model.path, systemPrompt)
+                true
+            } catch (e: Exception) {
+                Log.e("InferenceManager", "Error loading model", e)
+                false
+            }
+        } ?: false
+    }
+
+    /**
+     * Run benchmark
+     */
+    suspend fun benchmark(
+        pp: Int = 512,
+        tg: Int = 128,
+        pl: Int = 1,
+        nr: Int = 3
+    ): String = inferenceEngine.bench(pp, tg, pl, nr)
+
+    /**
+     * Generate response from prompt
+     */
+    fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
+        try {
+            // Reset metrics tracking
+            generationStartTime = System.currentTimeMillis()
+            firstTokenTime = 0L
+            tokenCount = 0
+            isFirstToken = true
+
+            val response = StringBuilder()
+
+            inferenceEngine.sendUserPrompt(prompt)
+                .collect { token ->
+                    // Track first token time
+                    if (isFirstToken && token.isNotBlank()) {
+                        firstTokenTime = System.currentTimeMillis()
+                        isFirstToken = false
+                    }
+
+                    // Count tokens
+                    if (token.isNotBlank()) {
+                        tokenCount++
+                    }
+
+                    response.append(token)
+
+                    // Emit ongoing response (not completed)
+                    emit(Pair(response.toString(), false))
+                }
+
+            // Calculate final metrics after completion
+            val metrics = createTokenMetrics()
+
+            // Emit final response with completion flag
+            emit(Pair(response.toString(), true))
+        } catch (e: Exception) {
+            // Emit error
+            val metrics = createTokenMetrics()
+            throw e
+        }
+    }
+
+    /**
+     * Create token metrics based on current state
+     */
+    fun createTokenMetrics(): TokenMetrics {
+        val endTime = System.currentTimeMillis()
+        val totalTimeMs = endTime - generationStartTime
+
+        return TokenMetrics(
+            tokensCount = tokenCount,
+            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
+            tpsMs = calculateTPS(tokenCount, totalTimeMs)
+        )
+    }
+
+    /**
+     * Calculate tokens per second
+     */
+    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
+        if (tokens <= 0 || timeMs <= 0) return 0f
+        return (tokens.toFloat() * 1000f) / timeMs
+    }
+
+    /**
+     * Unload current model
+     */
+    suspend fun unloadModel() = inferenceEngine.unloadModel()
+
+    /**
+     * Cleanup resources
+     */
+    fun destroy() = inferenceEngine.destroy()
+}
+
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}
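
A minimal usage sketch for reviewers, not part of the patch: one way a caller could drive InferenceManager, assuming Hilt view-model injection is wired up in the app module. ChatViewModel and its _uiText/_metrics holders are hypothetical names used only to illustrate consuming generateResponse() and createTokenMetrics().

    // Hypothetical consumer; assumes Hilt provides the @Singleton InferenceManager.
    import androidx.lifecycle.ViewModel
    import androidx.lifecycle.viewModelScope
    import dagger.hilt.android.lifecycle.HiltViewModel
    import kotlinx.coroutines.flow.MutableStateFlow
    import kotlinx.coroutines.flow.StateFlow
    import kotlinx.coroutines.flow.asStateFlow
    import kotlinx.coroutines.launch
    import javax.inject.Inject

    @HiltViewModel
    class ChatViewModel @Inject constructor(
        private val inferenceManager: InferenceManager
    ) : ViewModel() {

        private val _uiText = MutableStateFlow("")
        val uiText: StateFlow<String> = _uiText.asStateFlow()

        private val _metrics = MutableStateFlow<TokenMetrics?>(null)
        val metrics: StateFlow<TokenMetrics?> = _metrics.asStateFlow()

        fun ask(prompt: String) {
            viewModelScope.launch {
                // Stream partial responses; the Boolean marks the final emission.
                inferenceManager.generateResponse(prompt).collect { (text, done) ->
                    _uiText.value = text
                    if (done) _metrics.value = inferenceManager.createTokenMetrics()
                }
            }
        }
    }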