core: extract conversation and benchmark logics into InferenceManager; add logs and missing state updates in stub InferenceEngine
This commit is contained in:
parent
51b120f464
commit
32d778bb8e
|
|
@ -1,5 +1,6 @@
|
||||||
package com.example.llama.revamp.engine
|
package com.example.llama.revamp.engine
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
import kotlinx.coroutines.CancellationException
|
import kotlinx.coroutines.CancellationException
|
||||||
import kotlinx.coroutines.delay
|
import kotlinx.coroutines.delay
|
||||||
import kotlinx.coroutines.flow.Flow
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
|
@ -13,7 +14,9 @@ import kotlinx.coroutines.flow.flow
|
||||||
*/
|
*/
|
||||||
class InferenceEngine {
|
class InferenceEngine {
|
||||||
companion object {
|
companion object {
|
||||||
const val DEFAULT_PREDICT_LENGTH = 1024
|
private val TAG = InferenceEngine::class.java.simpleName
|
||||||
|
|
||||||
|
private const val DEFAULT_PREDICT_LENGTH = 1024
|
||||||
}
|
}
|
||||||
|
|
||||||
sealed class State {
|
sealed class State {
|
||||||
|
|
@ -45,6 +48,8 @@ class InferenceEngine {
|
||||||
val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
|
val benchmarkResults: StateFlow<String?> = _benchmarkResultsFlow
|
||||||
|
|
||||||
init {
|
init {
|
||||||
|
Log.i(TAG, "Initiated!")
|
||||||
|
|
||||||
// Simulate library loading
|
// Simulate library loading
|
||||||
_state.value = State.LibraryLoaded
|
_state.value = State.LibraryLoaded
|
||||||
}
|
}
|
||||||
|
|
@ -53,6 +58,8 @@ class InferenceEngine {
|
||||||
* Loads a model from the given path with an optional system prompt.
|
* Loads a model from the given path with an optional system prompt.
|
||||||
*/
|
*/
|
||||||
suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
|
suspend fun loadModel(pathToModel: String, systemPrompt: String? = null) {
|
||||||
|
Log.i(TAG, "loadModel! state: ${_state.value}")
|
||||||
|
|
||||||
try {
|
try {
|
||||||
_state.value = State.LoadingModel
|
_state.value = State.LoadingModel
|
||||||
|
|
||||||
|
|
@ -81,6 +88,8 @@ class InferenceEngine {
|
||||||
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
|
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
|
||||||
*/
|
*/
|
||||||
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
|
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> {
|
||||||
|
Log.i(TAG, "sendUserPrompt! state: ${_state.value}")
|
||||||
|
|
||||||
_state.value = State.ProcessingUserPrompt
|
_state.value = State.ProcessingUserPrompt
|
||||||
|
|
||||||
// This would be replaced with actual token generation logic
|
// This would be replaced with actual token generation logic
|
||||||
|
|
@ -123,6 +132,8 @@ class InferenceEngine {
|
||||||
* Runs a benchmark with the specified parameters.
|
* Runs a benchmark with the specified parameters.
|
||||||
*/
|
*/
|
||||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
||||||
|
Log.i(TAG, "bench! state: ${_state.value}")
|
||||||
|
|
||||||
_state.value = State.Benchmarking
|
_state.value = State.Benchmarking
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
|
@ -169,6 +180,8 @@ class InferenceEngine {
|
||||||
* Unloads the currently loaded model.
|
* Unloads the currently loaded model.
|
||||||
*/
|
*/
|
||||||
suspend fun unloadModel() {
|
suspend fun unloadModel() {
|
||||||
|
Log.i(TAG, "unloadModel! state: ${_state.value}")
|
||||||
|
|
||||||
// Simulate model unloading time
|
// Simulate model unloading time
|
||||||
delay(2000)
|
delay(2000)
|
||||||
_state.value = State.LibraryLoaded
|
_state.value = State.LibraryLoaded
|
||||||
|
|
@ -180,6 +193,8 @@ class InferenceEngine {
|
||||||
* Cleans up resources when the engine is no longer needed.
|
* Cleans up resources when the engine is no longer needed.
|
||||||
*/
|
*/
|
||||||
fun destroy() {
|
fun destroy() {
|
||||||
// In a real implementation, this would release native resources
|
Log.i(TAG, "destroy! state: ${_state.value}")
|
||||||
|
|
||||||
|
_state.value = State.Uninitialized
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,169 @@
|
||||||
|
package com.example.llama.revamp.engine
|
||||||
|
|
||||||
|
import android.util.Log
|
||||||
|
import com.example.llama.revamp.data.model.ModelInfo
|
||||||
|
import kotlinx.coroutines.flow.Flow
|
||||||
|
import kotlinx.coroutines.flow.MutableStateFlow
|
||||||
|
import kotlinx.coroutines.flow.StateFlow
|
||||||
|
import kotlinx.coroutines.flow.asStateFlow
|
||||||
|
import kotlinx.coroutines.flow.flow
|
||||||
|
import javax.inject.Inject
|
||||||
|
import javax.inject.Singleton
|
||||||
|
|
||||||
|
@Singleton
class InferenceManager @Inject constructor(
    private val inferenceEngine: InferenceEngine
) {
    // Engine state, re-exposed so callers depend on the manager only.
    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state

    // Latest benchmark results produced by the engine.
    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults

    // Model that load/benchmark/conversation operations act on.
    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
    val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()

    // System prompt used when a model is loaded for conversation.
    private val _systemPrompt = MutableStateFlow<String?>(null)
    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()

    // Token-metrics bookkeeping for the generation currently in flight.
    // NOTE(review): not thread-safe; assumes one generation at a time — confirm callers.
    private var generationStartTime: Long = 0L
    private var firstTokenTime: Long = 0L
    private var tokenCount: Int = 0
    private var isFirstToken: Boolean = true

    /**
     * Sets the model that subsequent load operations will use.
     */
    fun setCurrentModel(model: ModelInfo) {
        _currentModel.value = model
    }

    /**
     * Loads the current model for benchmarking (no system prompt).
     *
     * @return true on success, false when no model is selected or loading fails.
     */
    suspend fun loadModelForBenchmark(): Boolean = loadCurrentModel(systemPrompt = null)

    /**
     * Loads the current model for conversation, remembering [systemPrompt].
     *
     * @return true on success, false when no model is selected or loading fails.
     */
    suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
        _systemPrompt.value = systemPrompt
        return loadCurrentModel(systemPrompt)
    }

    // Shared load path for both entry points above.
    private suspend fun loadCurrentModel(systemPrompt: String?): Boolean {
        val model = _currentModel.value ?: return false
        return try {
            inferenceEngine.loadModel(model.path, systemPrompt)
            true
        } catch (e: kotlin.coroutines.cancellation.CancellationException) {
            // Never convert coroutine cancellation into a "load failed" result.
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Error loading model", e)
            false
        }
    }

    /**
     * Runs a benchmark on the loaded model via the engine.
     */
    suspend fun benchmark(
        pp: Int = 512,
        tg: Int = 128,
        pl: Int = 1,
        nr: Int = 3,
    ): String = inferenceEngine.bench(pp, tg, pl, nr)

    /**
     * Streams the model's response to [prompt].
     *
     * Emits (accumulated response text, isComplete) pairs: one pair per token
     * with isComplete = false, then a final pair with isComplete = true.
     * Token metrics are tracked as a side effect; read them afterwards with
     * [createTokenMetrics].
     */
    fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
        // Reset metrics tracking for this generation.
        generationStartTime = System.currentTimeMillis()
        firstTokenTime = 0L
        tokenCount = 0
        isFirstToken = true

        val response = StringBuilder()

        inferenceEngine.sendUserPrompt(prompt).collect { token ->
            // Blank tokens (padding/whitespace-only) are appended but not counted.
            if (token.isNotBlank()) {
                if (isFirstToken) {
                    firstTokenTime = System.currentTimeMillis()
                    isFirstToken = false
                }
                tokenCount++
            }

            response.append(token)

            // Emit the ongoing (not yet completed) response.
            emit(Pair(response.toString(), false))
        }

        // Emit the final response with the completion flag set.
        emit(Pair(response.toString(), true))
    }

    /**
     * Builds a [TokenMetrics] snapshot from the bookkeeping of the most
     * recent (or in-flight) generation.
     */
    fun createTokenMetrics(): TokenMetrics {
        val endTime = System.currentTimeMillis()
        val totalTimeMs = endTime - generationStartTime

        return TokenMetrics(
            tokensCount = tokenCount,
            // TTFT is 0 when no non-blank token was ever observed.
            ttftMs = if (firstTokenTime > 0) firstTokenTime - generationStartTime else 0L,
            tpsMs = calculateTPS(tokenCount, totalTimeMs),
        )
    }

    // Tokens per second; 0 for empty generations or non-positive durations.
    private fun calculateTPS(tokens: Int, timeMs: Long): Float {
        if (tokens <= 0 || timeMs <= 0) return 0f
        return (tokens.toFloat() * 1000f) / timeMs
    }

    /**
     * Unloads the currently loaded model.
     */
    suspend fun unloadModel() = inferenceEngine.unloadModel()

    /**
     * Releases engine resources.
     */
    fun destroy() = inferenceEngine.destroy()

    companion object {
        // Same tag convention as InferenceEngine.
        private val TAG = InferenceManager::class.java.simpleName
    }
}
|
||||||
|
|
||||||
|
/**
 * Immutable snapshot of token-generation metrics.
 *
 * @property tokensCount number of non-blank tokens generated
 * @property ttftMs time to first token, in milliseconds
 * @property tpsMs tokens-per-second throughput
 */
data class TokenMetrics(
    val tokensCount: Int,
    val ttftMs: Long,
    val tpsMs: Float,
) {
    // One-line human-readable summary for display/logging.
    val text: String
        get() = buildString {
            append("Tokens: ").append(tokensCount)
            append(", TTFT: ").append(ttftMs).append("ms")
            append(", TPS: ").append("%.1f".format(tpsMs))
        }
}
|
||||||
Loading…
Reference in New Issue