core: extract conversation and benchmark logic into InferenceManager; add logging and missing state updates to the stub InferenceEngine

Han Yin 2025-04-15 13:46:52 -07:00
parent 51b120f464
commit 32d778bb8e
2 changed files with 186 additions and 2 deletions

InferenceEngine.kt

@@ -1,5 +1,6 @@
package com.example.llama.revamp.engine
import android.util.Log
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
@@ -13,7 +14,9 @@ import kotlinx.coroutines.flow.flow
*/
class InferenceEngine {
companion object {
const val DEFAULT_PREDICT_LENGTH = 1024
private val TAG = InferenceEngine::class.java.simpleName
private const val DEFAULT_PREDICT_LENGTH = 1024
}
sealed class State {
@@ -45,6 +48,8 @@
val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
init {
Log.i(TAG, "Initiated!")
// Simulate library loading
_state.value = State.LibraryLoaded
}
@@ -53,6 +58,8 @@
* Loads a model from the given path with an optional system prompt.
*/
suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
Log.i(TAG, "loadModel! state: ${_state.value}")
try {
_state.value = State.LoadingModel
@@ -81,6 +88,8 @@
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
_state.value = State.ProcessingUserPrompt
// This would be replaced with actual token generation logic
@@ -123,6 +132,8 @@
* Runs a benchmark with the specified parameters.
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
Log.i(TAG, "bench! state: ${_state.value}")
_state.value = State.Benchmarking
try {
@@ -169,6 +180,8 @@
* Unloads the currently loaded model.
*/
suspend fun unloadModel() {
Log.i(TAG, "unloadModel! state: ${_state.value}")
// Simulate model unloading time
delay(2000)
_state.value = State.LibraryLoaded
@@ -180,6 +193,8 @@
* Cleans up resources when the engine is no longer needed.
*/
fun destroy() {
// In a real implementation, this would release native resources
Log.i(TAG, "destroy! state: ${_state.value}")
_state.value = State.Uninitialized
}
}

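For reference between the two files: the stub engine is driven directly through its suspend/Flow API. A minimal, hypothetical caller sketch (the coroutine scope, model path, prompt, and demoEngineRun function below are illustrative assumptions, not part of this commit):

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch

fun demoEngineRun() {
    val engine = InferenceEngine()
    CoroutineScope(Dispatchers.Default).launch {
        // Load a model, stream tokens from a prompt, then release everything
        engine.loadModel("/data/local/tmp/model.gguf")   // illustrative path
        engine.sendUserPrompt("Hello!").collect { token -> print(token) }
        engine.unloadModel()
        engine.destroy()
    }
}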
InferenceManager.kt

@@ -0,0 +1,169 @@
package com.example.llama.revamp.engine
import android.util.Log
import com.example.llama.revamp.data.model.ModelInfo
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.flow
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class InferenceManager @Inject constructor(
private val inferenceEngine: InferenceEngine
) {
// Expose engine state
val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
// Benchmark results
val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
// Currently loaded model
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
// System prompt
private val _systemPrompt = MutableStateFlow<String?>(null)
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
// Token metrics tracking
private var generationStartTime: Long = 0L
private var firstTokenTime: Long = 0L
private var tokenCount: Int = 0
private var isFirstToken: Boolean = true
/**
* Set current model
*/
fun setCurrentModel(model: ModelInfo) {
_currentModel.value = model
}
/**
* Load a model for benchmark
*/
suspend fun loadModelForBenchmark(): Boolean {
return _currentModel.value?.let { model ->
try {
inferenceEngine.loadModel(model.path)
true
} catch (e: Exception) {
Log.e("InferenceManager", "Error loading model", e)
false
}
} ?: false
}
/**
* Load a model for conversation
*/
suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
_systemPrompt.value = systemPrompt
return _currentModel.value?.let { model ->
try {
inferenceEngine.loadModel(model.path, systemPrompt)
true
} catch (e: Exception) {
Log.e("InferenceManager", "Error loading model", e)
false
}
} ?: false
}
/**
* Run benchmark
*/
suspend fun benchmark(
pp: Int = 512,
tg: Int = 128,
pl: Int = 1,
nr: Int = 3
): String = inferenceEngine.bench(pp, tg, pl, nr)
/**
* Generate response from prompt
*/
fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
try {
// Reset metrics tracking
generationStartTime = System.currentTimeMillis()
firstTokenTime = 0L
tokenCount = 0
isFirstToken = true
val response = StringBuilder()
inferenceEngine.sendUserPrompt(prompt)
.collect { token ->
// Track first token time
if (isFirstToken && token.isNotBlank()) {
firstTokenTime = System.currentTimeMillis()
isFirstToken = false
}
// Count tokens
if (token.isNotBlank()) {
tokenCount++
}
response.append(token)
// Emit ongoing response (not completed)
emit(Pair(response.toString(), false))
}
// Final metrics remain available to callers via createTokenMetrics()
// Emit final response with completion flag
emit(Pair(response.toString(), true))
} catch (e: Exception) {
// No partial pair is emitted on failure; log and rethrow so collectors can handle it
Log.e("InferenceManager", "Error during generation", e)
throw e
}
}
/**
* Create token metrics based on current state
*/
fun createTokenMetrics(): TokenMetrics {
val endTime = System.currentTimeMillis()
val totalTimeMs = endTime - generationStartTime
return TokenMetrics(
tokensCount = tokenCount,
ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
tps = calculateTPS(tokenCount, totalTimeMs)
)
}
/**
* Calculate tokens per second
*/
private fun calculateTPS(tokens: Int, timeMs: Long): Float {
if (tokens <= 0 || timeMs <= 0) return 0f
return (tokens.toFloat() * 1000f) / timeMs
}
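// Illustrative check of the formula above (values are made up, not from this commit):
// 42 tokens over 3500 ms -> 42 * 1000f / 3500 = 12.0 tokens per second.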
/**
* Unload current model
*/
suspend fun unloadModel() = inferenceEngine.unloadModel()
/**
* Cleanup resources
*/
fun destroy() = inferenceEngine.destroy()
}
/**
 * Per-generation metrics: token count, time to first token, and throughput.
 */
data class TokenMetrics(
val tokensCount: Int,
val ttftMs: Long,
val tps: Float,
) {
val text: String
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tps)}"
}
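As a usage sketch, a hypothetical ViewModel could drive the manager end to end; the ChatViewModel class, its state flow, and the ModelInfo instance passed in are assumptions for illustration, not part of this commit:

import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.InferenceManager
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.launch

class ChatViewModel(private val inferenceManager: InferenceManager) : ViewModel() {

    private val _responseText = MutableStateFlow("")
    val responseText: StateFlow<String> = _responseText.asStateFlow()

    fun ask(model: ModelInfo, prompt: String) = viewModelScope.launch {
        inferenceManager.setCurrentModel(model)
        if (!inferenceManager.loadModelForConversation()) return@launch

        // Stream partial responses; the Boolean flags the final emission
        inferenceManager.generateResponse(prompt).collect { (partial, done) ->
            _responseText.value = partial
            if (done) {
                // Metrics are finalized once the completion flag arrives
                Log.i("ChatViewModel", inferenceManager.createTokenMetrics().text)
            }
        }
        inferenceManager.unloadModel()
    }
}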