core: break down InferenceManager per the Interface Segregation Principle

Han Yin 2025-04-15 21:41:48 -07:00
parent 286ed05f13
commit 9cfa74f754
7 changed files with 179 additions and 114 deletions
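At a glance, the split introduced here looks as follows. This is an orientation sketch only: the interface names are taken from the diff below, the member lists are summarized in comments, and the bodies are elided.

// Orientation sketch, not code from this commit: the old InferenceManager becomes one shared
// base interface plus three role-specific interfaces, so each ViewModel depends only on the
// capability it actually uses (Interface Segregation Principle).
interface InferenceService                         // engine state, selected model, unloadModel()
interface ModelLoadingService : InferenceService   // loadModelForBenchmark(), loadModelForConversation()
interface BenchmarkService : InferenceService      // benchmark(pp, tg, pl, nr), results
interface ConversationService : InferenceService   // generateResponse(prompt), systemPrompt, createTokenMetrics()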

View File

@@ -6,7 +6,12 @@ import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
import com.example.llama.revamp.engine.BenchmarkService
import com.example.llama.revamp.engine.ConversationService
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.engine.InferenceService
import com.example.llama.revamp.engine.InferenceServiceImpl
import com.example.llama.revamp.engine.ModelLoadingService
import com.example.llama.revamp.monitoring.PerformanceMonitor
import dagger.Binds
import dagger.Module
@@ -18,7 +23,19 @@ import javax.inject.Singleton
@Module
@InstallIn(SingletonComponent::class)
abstract class AppModule {
internal abstract class AppModule {
@Binds
abstract fun bindInferenceService(impl: InferenceServiceImpl) : InferenceService
@Binds
abstract fun bindModelLoadingService(impl: InferenceServiceImpl) : ModelLoadingService
@Binds
abstract fun bindBenchmarkService(impl: InferenceServiceImpl) : BenchmarkService
@Binds
abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
@Binds
abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
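A scoping note on the bindings above, with a hypothetical consumer (ExampleSharedStateViewModel is illustrative and not part of this commit): the @Singleton annotation sits on InferenceServiceImpl itself, and each @Binds method only delegates to that binding, so all four interfaces resolve to the same shared instance. State such as the currently selected model therefore stays consistent regardless of which interface a ViewModel injects.

// Hypothetical consumer, for illustration only (not part of this commit).
package com.example.llama.revamp.viewmodel

import androidx.lifecycle.ViewModel
import com.example.llama.revamp.engine.BenchmarkService
import com.example.llama.revamp.engine.ModelLoadingService
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject

@HiltViewModel
class ExampleSharedStateViewModel @Inject constructor(
    modelLoading: ModelLoadingService,
    benchmark: BenchmarkService,
) : ViewModel() {
    // Both properties observe the same MutableStateFlow inside the shared InferenceServiceImpl,
    // because the singleton scope lives on the implementation rather than on the @Binds methods.
    val selectedModelForLoading = modelLoading.currentSelectedModel
    val selectedModelForBenchmark = benchmark.currentSelectedModel
}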

View File

@@ -10,41 +10,124 @@ import kotlinx.coroutines.flow.flow
import javax.inject.Inject
import javax.inject.Singleton
@Singleton
class InferenceManager @Inject constructor(
private val inferenceEngine: InferenceEngine
) {
// Expose engine state
val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
interface InferenceService {
/**
* Expose engine state
*/
val engineState: StateFlow<InferenceEngine.State>
// Benchmark results
val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
// Currently loaded model
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
// System prompt
private val _systemPrompt = MutableStateFlow<String?>(null)
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
// Token metrics tracking
private var generationStartTime: Long = 0L
private var firstTokenTime: Long = 0L
private var tokenCount: Int = 0
private var isFirstToken: Boolean = true
/**
* Currently selected model
*/
val currentSelectedModel: StateFlow<ModelInfo?>
/**
* Set current model
*/
fun setCurrentModel(model: ModelInfo) {
_currentModel.value = model
fun setCurrentModel(model: ModelInfo)
/**
* Unload current model and free resources
*/
suspend fun unloadModel()
}
interface ModelLoadingService : InferenceService {
/**
* Load a model for benchmark
*/
suspend fun loadModelForBenchmark(): Boolean {
suspend fun loadModelForBenchmark(): Boolean
/**
* Load a model for conversation
*/
suspend fun loadModelForConversation(systemPrompt: String?): Boolean
}
interface BenchmarkService : InferenceService {
/**
* Run benchmark
*
* @param pp prompt processing size
* @param tg token generation size
* @param pl parallel sequences
* @param nr number of runs (repetitions)
*/
suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
/**
* Benchmark results
*/
val results: StateFlow<String?>
}
interface ConversationService : InferenceService {
/**
* System prompt
*/
val systemPrompt: StateFlow<String?>
/**
* Generate response from prompt
*/
fun generateResponse(prompt: String): Flow<GenerationUpdate>
/**
* Create token metrics based on current state
*/
fun createTokenMetrics(): TokenMetrics
}
/**
* Represents an update during text generation
*/
data class GenerationUpdate(
val text: String,
val isComplete: Boolean
)
/**
* Metrics for token generation performance
*/
data class TokenMetrics(
val tokensCount: Int,
val ttftMs: Long,
val tpsMs: Float,
) {
val text: String
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}
/**
* Internal implementation of the above [InferenceService]s
*/
@Singleton
internal class InferenceServiceImpl @Inject internal constructor(
private val inferenceEngine: InferenceEngine
) : ModelLoadingService, BenchmarkService, ConversationService {
/* InferenceService implementation */
override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
override fun setCurrentModel(model: ModelInfo) {
_currentModel.value = model
}
override suspend fun unloadModel() = inferenceEngine.unloadModel()
/**
* Shut down inference engine
*/
fun destroy() = inferenceEngine.destroy()
/* ModelLoadingService implementation */
override suspend fun loadModelForBenchmark(): Boolean {
return _currentModel.value?.let { model ->
try {
inferenceEngine.loadModel(model.path)
@@ -53,13 +136,10 @@ class InferenceManager @Inject constructor(
Log.e("InferenceManager", "Error loading model", e)
false
}
} ?: false
} == true
}
/**
* Load a model for conversation
*/
suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
_systemPrompt.value = systemPrompt
return _currentModel.value?.let { model ->
try {
@@ -69,23 +149,30 @@ class InferenceManager @Inject constructor(
Log.e("InferenceManager", "Error loading model", e)
false
}
} ?: false
} == true
}
/**
* Run benchmark
*/
suspend fun benchmark(
pp: Int = 512,
tg: Int = 128,
pl: Int = 1,
nr: Int = 3
): String = inferenceEngine.bench(pp, tg, pl, nr)
/**
* Generate response from prompt
*/
fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
/* BenchmarkService implementation */
override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
inferenceEngine.bench(pp, tg, pl, nr)
override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
/* ConversationService implementation */
private val _systemPrompt = MutableStateFlow<String?>(null)
override val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
// Token metrics tracking
private var generationStartTime: Long = 0L
private var firstTokenTime: Long = 0L
private var tokenCount: Int = 0
private var isFirstToken: Boolean = true
override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
try {
// Reset metrics tracking
generationStartTime = System.currentTimeMillis()
@@ -111,14 +198,14 @@ class InferenceManager @Inject constructor(
response.append(token)
// Emit ongoing response (not completed)
emit(Pair(response.toString(), false))
emit(GenerationUpdate(response.toString(), false))
}
// Calculate final metrics after completion
val metrics = createTokenMetrics()
// Emit final response with completion flag
emit(Pair(response.toString(), true))
emit(GenerationUpdate(response.toString(), true))
} catch (e: Exception) {
// Emit error
val metrics = createTokenMetrics()
@@ -126,10 +213,7 @@ class InferenceManager @Inject constructor(
}
}
/**
* Create token metrics based on current state
*/
fun createTokenMetrics(): TokenMetrics {
override fun createTokenMetrics(): TokenMetrics {
val endTime = System.currentTimeMillis()
val totalTimeMs = endTime - generationStartTime
@@ -147,23 +231,4 @@ class InferenceManager @Inject constructor(
if (tokens <= 0 || timeMs <= 0) return 0f
return (tokens.toFloat() * 1000f) / timeMs
}
/**
* Unload current model
*/
suspend fun unloadModel() = inferenceEngine.unloadModel()
/**
* Cleanup resources
*/
fun destroy() = inferenceEngine.destroy()
}
data class TokenMetrics(
val tokensCount: Int,
val ttftMs: Long,
val tpsMs: Float,
) {
val text: String
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}
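For context, a minimal sketch of consuming the new conversation API (collectFinalResponse is a hypothetical helper, not part of this commit). Because GenerationUpdate is a data class, the destructuring collect that ConversationViewModel uses further down keeps working after the switch away from Pair<String, Boolean>.

// Hypothetical caller, for illustration only.
package com.example.llama.revamp.engine.sample

import com.example.llama.revamp.engine.ConversationService

suspend fun collectFinalResponse(service: ConversationService, prompt: String): String {
    var finalText = ""
    service.generateResponse(prompt).collect { (text, isComplete) ->
        // GenerationUpdate provides component1()/component2(), so destructuring works here.
        if (isComplete) {
            finalText = text
            // TokenMetrics.text formats e.g. "Tokens: 42, TTFT: 180ms, TPS: 23.5" (illustrative numbers).
            println(service.createTokenMetrics().text)
        }
    }
    return finalText
}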

View File

@@ -3,8 +3,8 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.engine.BenchmarkService
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.engine.InferenceManager
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.launch
@@ -12,18 +12,18 @@ import javax.inject.Inject
@HiltViewModel
class BenchmarkViewModel @Inject constructor(
private val inferenceManager: InferenceManager
private val benchmarkService: BenchmarkService
) : ViewModel() {
val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
val benchmarkResults: StateFlow<String?> = benchmarkService.results
val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
/**
* Run benchmark with specified parameters
*/
fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
viewModelScope.launch {
inferenceManager.benchmark(pp, tg, pl, nr)
benchmarkService.benchmark(pp, tg, pl, nr)
}
}

View File

@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.engine.InferenceManager
import com.example.llama.revamp.engine.ConversationService
import com.example.llama.revamp.engine.TokenMetrics
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.Job
@@ -14,17 +14,16 @@ import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import javax.inject.Inject
import kotlin.getValue
@HiltViewModel
class ConversationViewModel @Inject constructor(
private val inferenceManager: InferenceManager
private val conversationService: ConversationService
) : ViewModel() {
val engineState = inferenceManager.engineState
val selectedModel = inferenceManager.currentModel
val systemPrompt = inferenceManager.systemPrompt
val engineState = conversationService.engineState
val selectedModel = conversationService.currentSelectedModel
val systemPrompt = conversationService.systemPrompt
// Messages in conversation
private val _messages = MutableStateFlow<List<Message>>(emptyList())
@@ -60,7 +59,7 @@ class ConversationViewModel @Inject constructor(
// Collect response
tokenCollectionJob = viewModelScope.launch {
try {
inferenceManager.generateResponse(content)
conversationService.generateResponse(content)
.collect { (text, isComplete) ->
updateAssistantMessage(text, isComplete)
}
@@ -85,7 +84,7 @@ class ConversationViewModel @Inject constructor(
currentMessages[lastIndex] = Message.Assistant.Completed(
content = text,
timestamp = currentAssistantMessage.timestamp,
metrics = inferenceManager.createTokenMetrics()
metrics = conversationService.createTokenMetrics()
)
} else {
// Ongoing message update
@@ -110,7 +109,7 @@ class ConversationViewModel @Inject constructor(
currentMessages[lastIndex] = Message.Assistant.Completed(
content = "${currentAssistantMessage.content}[Error: ${e.message}]",
timestamp = currentAssistantMessage.timestamp,
metrics = inferenceManager.createTokenMetrics()
metrics = conversationService.createTokenMetrics()
)
_messages.value = currentMessages
}

View File

@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
import androidx.lifecycle.ViewModel
import com.example.llama.revamp.engine.InferenceEngine
import com.example.llama.revamp.engine.InferenceManager
import com.example.llama.revamp.engine.InferenceService
import dagger.hilt.android.lifecycle.HiltViewModel
import javax.inject.Inject
@@ -11,20 +11,14 @@ import javax.inject.Inject
* Main ViewModel that exposes the core states of [InferenceEngine]
*/
class MainViewModel @Inject constructor (
private val inferenceManager: InferenceManager,
private val inferenceService: InferenceService,
) : ViewModel() {
val engineState = inferenceManager.engineState
val engineState = inferenceService.engineState
/**
* Unload the current model and release the resources
*/
suspend fun unloadModel() = inferenceManager.unloadModel()
companion object {
private val TAG = MainViewModel::class.java.simpleName
private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
}
suspend fun unloadModel() = inferenceService.unloadModel()
}

View File

@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.SystemPrompt
import com.example.llama.revamp.data.repository.SystemPromptRepository
import com.example.llama.revamp.engine.InferenceManager
import com.example.llama.revamp.engine.ModelLoadingService
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
@@ -14,14 +14,14 @@ import javax.inject.Inject
@HiltViewModel
class ModelLoadingViewModel @Inject constructor(
private val inferenceManager: InferenceManager,
private val modelLoadingService: ModelLoadingService,
private val repository: SystemPromptRepository
) : ViewModel() {
/**
* Currently selected model to be loaded
*/
val selectedModel = inferenceManager.currentModel
val selectedModel = modelLoadingService.currentSelectedModel
/**
* Preset prompts
@@ -83,13 +83,13 @@ class ModelLoadingViewModel @Inject constructor(
* Prepares the engine for benchmark mode.
*/
suspend fun prepareForBenchmark() =
inferenceManager.loadModelForBenchmark()
modelLoadingService.loadModelForBenchmark()
/**
* Prepare for conversation
*/
suspend fun prepareForConversation(systemPrompt: String? = null) =
inferenceManager.loadModelForConversation(systemPrompt)
modelLoadingService.loadModelForConversation(systemPrompt)
companion object {

View File

@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.example.llama.revamp.data.model.ModelInfo
import com.example.llama.revamp.data.repository.ModelRepository
import com.example.llama.revamp.engine.InferenceManager
import com.example.llama.revamp.engine.InferenceService
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.SharingStarted
import kotlinx.coroutines.flow.StateFlow
@@ -15,7 +15,7 @@ import javax.inject.Inject
@HiltViewModel
class ModelSelectionViewModel @Inject constructor(
private val inferenceManager: InferenceManager,
private val inferenceService: InferenceService,
private val modelRepository: ModelRepository
) : ViewModel() {
@@ -29,27 +29,17 @@ class ModelSelectionViewModel @Inject constructor(
initialValue = emptyList()
)
/**
* Access to currently selected model
*/
val selectedModel = inferenceManager.currentModel
/**
* Select a model and update its last used timestamp
*/
fun selectModel(modelInfo: ModelInfo) {
inferenceManager.setCurrentModel(modelInfo)
inferenceService.setCurrentModel(modelInfo)
viewModelScope.launch {
modelRepository.updateModelLastUsed(modelInfo.id)
}
}
/**
* Unload model when navigating away
*/
suspend fun unloadModel() = inferenceManager.unloadModel()
companion object {
private val TAG = ModelSelectionViewModel::class.java.simpleName