core: break down InferenceManager following the Interface Segregation Principle

Han Yin 2025-04-15 21:41:48 -07:00
parent 286ed05f13
commit 9cfa74f754
7 changed files with 179 additions and 114 deletions

View File

@@ -6,7 +6,12 @@ import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
 import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
+import com.example.llama.revamp.engine.BenchmarkService
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceService
+import com.example.llama.revamp.engine.InferenceServiceImpl
+import com.example.llama.revamp.engine.ModelLoadingService
 import com.example.llama.revamp.monitoring.PerformanceMonitor
 import dagger.Binds
 import dagger.Module
@@ -18,7 +23,19 @@ import javax.inject.Singleton
 @Module
 @InstallIn(SingletonComponent::class)
-abstract class AppModule {
+internal abstract class AppModule {
+
+    @Binds
+    abstract fun bindInferenceService(impl: InferenceServiceImpl) : InferenceService
+
+    @Binds
+    abstract fun bindModelLoadingService(impl: InferenceServiceImpl) : ModelLoadingService
+
+    @Binds
+    abstract fun bindBenchmarkService(impl: InferenceServiceImpl) : BenchmarkService
+
+    @Binds
+    abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
 
     @Binds
     abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
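
Because InferenceServiceImpl is @Singleton-scoped, all four @Binds entries above resolve to the same underlying instance; each injection site just sees the narrow interface it asked for. A minimal consumer-side sketch under these bindings (the BenchmarkRunner class below is illustrative only, not part of this commit):

    import com.example.llama.revamp.engine.BenchmarkService
    import javax.inject.Inject

    // Hypothetical consumer: depends only on the benchmark-facing slice of the engine.
    // Hilt satisfies it via bindBenchmarkService(), handing back the shared
    // InferenceServiceImpl singleton behind the narrow BenchmarkService interface.
    class BenchmarkRunner @Inject constructor(
        private val benchmarkService: BenchmarkService
    ) {
        suspend fun runDefaultBenchmark(): String =
            benchmarkService.benchmark(pp = 512, tg = 128, pl = 1, nr = 3)
    }

Keeping the scope on the implementation rather than on the @Binds methods is a common Dagger pattern for exposing one object through several interfaces.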

View File

@@ -10,41 +10,124 @@ import kotlinx.coroutines.flow.flow
 import javax.inject.Inject
 import javax.inject.Singleton
 
-@Singleton
-class InferenceManager @Inject constructor(
-    private val inferenceEngine: InferenceEngine
-) {
-    // Expose engine state
-    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
-
-    // Benchmark results
-    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
-
-    // Currently loaded model
-    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
-    val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
-
-    // System prompt
-    private val _systemPrompt = MutableStateFlow<String?>(null)
-    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
-
-    // Token metrics tracking
-    private var generationStartTime: Long = 0L
-    private var firstTokenTime: Long = 0L
-    private var tokenCount: Int = 0
-    private var isFirstToken: Boolean = true
+interface InferenceService {
+    /**
+     * Expose engine state
+     */
+    val engineState: StateFlow<InferenceEngine.State>
+
+    /**
+     * Currently selected model
+     */
+    val currentSelectedModel: StateFlow<ModelInfo?>
 
     /**
      * Set current model
      */
-    fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    fun setCurrentModel(model: ModelInfo)
+
+    /**
+     * Unload current model and free resources
+     */
+    suspend fun unloadModel()
+}
+
+interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): Boolean {
+    suspend fun loadModelForBenchmark(): Boolean
+
+    /**
+     * Load a model for conversation
+     */
+    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+}
+
+interface BenchmarkService : InferenceService {
+    /**
+     * Run benchmark
+     *
+     * @param pp: Prompt Processing size
+     * @param tg: Token Generation size
+     * @param pl: Parallel sequences
+     * @param nr: repetitions (Number of Runs)
+     */
+    suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+    /**
+     * Benchmark results
+     */
+    val results: StateFlow<String?>
+}
+
+interface ConversationService : InferenceService {
+    /**
+     * System prompt
+     */
+    val systemPrompt: StateFlow<String?>
+
+    /**
+     * Generate response from prompt
+     */
+    fun generateResponse(prompt: String): Flow<GenerationUpdate>
+
+    /**
+     * Create token metrics based on current state
+     */
+    fun createTokenMetrics(): TokenMetrics
+}
+
+/**
+ * Represents an update during text generation
+ */
+data class GenerationUpdate(
+    val text: String,
+    val isComplete: Boolean
+)
+
+/**
+ * Metrics for token generation performance
+ */
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
+}
+
+/**
+ * Internal implementation of the above [InferenceService]s
+ */
+@Singleton
+internal class InferenceServiceImpl @Inject internal constructor(
+    private val inferenceEngine: InferenceEngine
+) : ModelLoadingService, BenchmarkService, ConversationService {
+
+    /* InferenceService implementation */
+    override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+
+    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
+    override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
+
+    override fun setCurrentModel(model: ModelInfo) {
+        _currentModel.value = model
+    }
+
+    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+
+    /**
+     * Shut down inference engine
+     */
+    fun destroy() = inferenceEngine.destroy()
+
+    /* ModelLoadingService implementation */
+    override suspend fun loadModelForBenchmark(): Boolean {
         return _currentModel.value?.let { model ->
             try {
                 inferenceEngine.loadModel(model.path)
@@ -53,13 +136,10 @@ class InferenceManager @Inject constructor(
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Load a model for conversation
-     */
-    suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
@@ -69,23 +149,30 @@
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Run benchmark
-     */
-    suspend fun benchmark(
-        pp: Int = 512,
-        tg: Int = 128,
-        pl: Int = 1,
-        nr: Int = 3
-    ): String = inferenceEngine.bench(pp, tg, pl, nr)
+    /* BenchmarkService implementation */
+    override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        inferenceEngine.bench(pp, tg, pl, nr)
+
+    override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
+
+    /* ConversationService implementation */
+    private val _systemPrompt = MutableStateFlow<String?>(null)
+    override val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
+
+    // Token metrics tracking
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
 
-    /**
-     * Generate response from prompt
-     */
-    fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
+    override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
         try {
             // Reset metrics tracking
             generationStartTime = System.currentTimeMillis()
@@ -111,14 +198,14 @@
                 response.append(token)
 
                 // Emit ongoing response (not completed)
-                emit(Pair(response.toString(), false))
+                emit(GenerationUpdate(response.toString(), false))
             }
 
             // Calculate final metrics after completion
             val metrics = createTokenMetrics()
 
             // Emit final response with completion flag
-            emit(Pair(response.toString(), true))
+            emit(GenerationUpdate(response.toString(), true))
         } catch (e: Exception) {
             // Emit error
             val metrics = createTokenMetrics()
@@ -126,10 +213,7 @@
         }
     }
 
-    /**
-     * Create token metrics based on current state
-     */
-    fun createTokenMetrics(): TokenMetrics {
+    override fun createTokenMetrics(): TokenMetrics {
         val endTime = System.currentTimeMillis()
         val totalTimeMs = endTime - generationStartTime
@@ -147,23 +231,4 @@
         if (tokens <= 0 || timeMs <= 0) return 0f
         return (tokens.toFloat() * 1000f) / timeMs
     }
-
-    /**
-     * Unload current model
-     */
-    suspend fun unloadModel() = inferenceEngine.unloadModel()
-
-    /**
-     * Cleanup resources
-     */
-    fun destroy() = inferenceEngine.destroy()
-}
-
-data class TokenMetrics(
-    val tokensCount: Int,
-    val ttftMs: Long,
-    val tpsMs: Float,
-) {
-    val text: String
-        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
 }
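
For orientation, a minimal sketch of how a caller might consume the new ConversationService surface (the printReply helper is illustrative, not part of this commit). Note that the ConversationViewModel hunk below keeps destructuring each emission as (text, isComplete), which still compiles because GenerationUpdate is a data class and provides component1/component2:

    import com.example.llama.revamp.engine.ConversationService

    // Illustrative caller: prints partial text as it streams and appends the
    // service-reported token metrics once the final, complete update arrives.
    suspend fun printReply(service: ConversationService, prompt: String) {
        service.generateResponse(prompt).collect { update ->
            println(update.text)
            if (update.isComplete) {
                println(service.createTokenMetrics().text)
            }
        }
    }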

View File

@@ -3,8 +3,8 @@ package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -12,18 +12,18 @@ import javax.inject.Inject
 @HiltViewModel
 class BenchmarkViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val benchmarkService: BenchmarkService
 ) : ViewModel() {
 
-    val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
-    val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
-    val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
+    val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
+    val benchmarkResults: StateFlow<String?> = benchmarkService.results
+    val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
 
     /**
      * Run benchmark with specified parameters
      */
     fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
         viewModelScope.launch {
-            inferenceManager.benchmark(pp, tg, pl, nr)
+            benchmarkService.benchmark(pp, tg, pl, nr)
         }
 }

View File

@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.TokenMetrics
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.Job
@@ -14,17 +14,16 @@ import java.text.SimpleDateFormat
 import java.util.Date
 import java.util.Locale
 import javax.inject.Inject
-import kotlin.getValue
 
 @HiltViewModel
 class ConversationViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val conversationService: ConversationService
 ) : ViewModel() {
-    val engineState = inferenceManager.engineState
-    val selectedModel = inferenceManager.currentModel
-    val systemPrompt = inferenceManager.systemPrompt
+    val engineState = conversationService.engineState
+    val selectedModel = conversationService.currentSelectedModel
+    val systemPrompt = conversationService.systemPrompt
 
     // Messages in conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
@@ -60,7 +59,7 @@
         // Collect response
         tokenCollectionJob = viewModelScope.launch {
             try {
-                inferenceManager.generateResponse(content)
+                conversationService.generateResponse(content)
                     .collect { (text, isComplete) ->
                         updateAssistantMessage(text, isComplete)
                     }
@@ -85,7 +84,7 @@
             currentMessages[lastIndex] = Message.Assistant.Completed(
                 content = text,
                 timestamp = currentAssistantMessage.timestamp,
-                metrics = inferenceManager.createTokenMetrics()
+                metrics = conversationService.createTokenMetrics()
             )
         } else {
             // Ongoing message update
@@ -110,7 +109,7 @@
             currentMessages[lastIndex] = Message.Assistant.Completed(
                 content = "${currentAssistantMessage.content}[Error: ${e.message}]",
                 timestamp = currentAssistantMessage.timestamp,
-                metrics = inferenceManager.createTokenMetrics()
+                metrics = conversationService.createTokenMetrics()
            )
            _messages.value = currentMessages
        }

View File

@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject
@@ -11,20 +11,14 @@ import javax.inject.Inject
  * Main ViewModel that expose the core states of [InferenceEngine]
  */
 class MainViewModel @Inject constructor (
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
 ) : ViewModel() {
-    val engineState = inferenceManager.engineState
+    val engineState = inferenceService.engineState
 
     /**
      * Unload the current model and release the resources
      */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
-    companion object {
-        private val TAG = MainViewModel::class.java.simpleName
-        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
-    }
+    suspend fun unloadModel() = inferenceService.unloadModel()
 }

View File

@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.SystemPrompt
 import com.example.llama.revamp.data.repository.SystemPromptRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ModelLoadingService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -14,14 +14,14 @@ import javax.inject.Inject
 @HiltViewModel
 class ModelLoadingViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val modelLoadingService: ModelLoadingService,
     private val repository: SystemPromptRepository
 ) : ViewModel() {
 
     /**
      * Currently selected model to be loaded
      */
-    val selectedModel = inferenceManager.currentModel
+    val selectedModel = modelLoadingService.currentSelectedModel
 
     /**
      * Preset prompts
@@ -83,13 +83,13 @@
      * Prepares the engine for benchmark mode.
      */
     suspend fun prepareForBenchmark() =
-        inferenceManager.loadModelForBenchmark()
+        modelLoadingService.loadModelForBenchmark()
 
     /**
      * Prepare for conversation
     */
     suspend fun prepareForConversation(systemPrompt: String? = null) =
-        inferenceManager.loadModelForConversation(systemPrompt)
+        modelLoadingService.loadModelForConversation(systemPrompt)
 
     companion object {

View File

@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.data.repository.ModelRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -15,7 +15,7 @@ import javax.inject.Inject
 @HiltViewModel
 class ModelSelectionViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
     private val modelRepository: ModelRepository
 ) : ViewModel() {
@@ -29,27 +29,17 @@
         initialValue = emptyList()
     )
 
-    /**
-     * Access to currently selected model
-     */
-    val selectedModel = inferenceManager.currentModel
-
     /**
      * Select a model and update its last used timestamp
      */
     fun selectModel(modelInfo: ModelInfo) {
-        inferenceManager.setCurrentModel(modelInfo)
+        inferenceService.setCurrentModel(modelInfo)
         viewModelScope.launch {
             modelRepository.updateModelLastUsed(modelInfo.id)
         }
     }
 
-    /**
-     * Unload model when navigating away
-     */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
     companion object {
         private val TAG = ModelSelectionViewModel::class.java.simpleName