core: break down InferenceManager due to Interface Segregation Principle
This commit is contained in:
parent
286ed05f13
commit
9cfa74f754
|
|
@ -6,7 +6,12 @@ import com.example.llama.revamp.data.repository.ModelRepository
|
|||
import com.example.llama.revamp.data.repository.ModelRepositoryImpl
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
|
||||
import com.example.llama.revamp.engine.BenchmarkService
|
||||
import com.example.llama.revamp.engine.ConversationService
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.engine.InferenceService
|
||||
import com.example.llama.revamp.engine.InferenceServiceImpl
|
||||
import com.example.llama.revamp.engine.ModelLoadingService
|
||||
import com.example.llama.revamp.monitoring.PerformanceMonitor
|
||||
import dagger.Binds
|
||||
import dagger.Module
|
||||
|
|
@ -18,7 +23,19 @@ import javax.inject.Singleton
|
|||
|
||||
@Module
|
||||
@InstallIn(SingletonComponent::class)
|
||||
abstract class AppModule {
|
||||
internal abstract class AppModule {
|
||||
|
||||
@Binds
|
||||
abstract fun bindInferenceService(impl: InferenceServiceImpl) : InferenceService
|
||||
|
||||
@Binds
|
||||
abstract fun bindModelLoadingService(impl: InferenceServiceImpl) : ModelLoadingService
|
||||
|
||||
@Binds
|
||||
abstract fun bindBenchmarkService(impl: InferenceServiceImpl) : BenchmarkService
|
||||
|
||||
@Binds
|
||||
abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
|
||||
|
||||
@Binds
|
||||
abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
|
||||
|
|
|
|||
|
|
@ -10,41 +10,124 @@ import kotlinx.coroutines.flow.flow
|
|||
import javax.inject.Inject
|
||||
import javax.inject.Singleton
|
||||
|
||||
@Singleton
|
||||
class InferenceManager @Inject constructor(
|
||||
private val inferenceEngine: InferenceEngine
|
||||
) {
|
||||
// Expose engine state
|
||||
val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
|
||||
interface InferenceService {
|
||||
/**
|
||||
* Expose engine state
|
||||
*/
|
||||
val engineState: StateFlow<InferenceEngine.State>
|
||||
|
||||
// Benchmark results
|
||||
val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
|
||||
|
||||
// Currently loaded model
|
||||
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
|
||||
val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
|
||||
|
||||
// System prompt
|
||||
private val _systemPrompt = MutableStateFlow<String?>(null)
|
||||
val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
|
||||
|
||||
// Token metrics tracking
|
||||
private var generationStartTime: Long = 0L
|
||||
private var firstTokenTime: Long = 0L
|
||||
private var tokenCount: Int = 0
|
||||
private var isFirstToken: Boolean = true
|
||||
/**
|
||||
* Currently selected model
|
||||
*/
|
||||
val currentSelectedModel: StateFlow<ModelInfo?>
|
||||
|
||||
/**
|
||||
* Set current model
|
||||
*/
|
||||
fun setCurrentModel(model: ModelInfo) {
|
||||
_currentModel.value = model
|
||||
fun setCurrentModel(model: ModelInfo)
|
||||
|
||||
/**
|
||||
* Unload current model and free resources
|
||||
*/
|
||||
suspend fun unloadModel()
|
||||
}
|
||||
|
||||
interface ModelLoadingService : InferenceService {
|
||||
/**
|
||||
* Load a model for benchmark
|
||||
*/
|
||||
suspend fun loadModelForBenchmark(): Boolean {
|
||||
suspend fun loadModelForBenchmark(): Boolean
|
||||
|
||||
/**
|
||||
* Load a model for conversation
|
||||
*/
|
||||
suspend fun loadModelForConversation(systemPrompt: String?): Boolean
|
||||
}
|
||||
|
||||
interface BenchmarkService : InferenceService {
|
||||
/**
|
||||
* Run benchmark
|
||||
*
|
||||
* @param pp: Prompt Processing size
|
||||
* @param tg: Token Generation size
|
||||
* @param pl: Parallel sequences
|
||||
* @param nr: repetitions (Number of Runs)
|
||||
*/
|
||||
suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
|
||||
|
||||
/**
|
||||
* Benchmark results
|
||||
*/
|
||||
val results: StateFlow<String?>
|
||||
}
|
||||
|
||||
interface ConversationService : InferenceService {
|
||||
/**
|
||||
* System prompt
|
||||
*/
|
||||
val systemPrompt: StateFlow<String?>
|
||||
|
||||
/**
|
||||
* Generate response from prompt
|
||||
*/
|
||||
fun generateResponse(prompt: String): Flow<GenerationUpdate>
|
||||
|
||||
/**
|
||||
* Create token metrics based on current state
|
||||
*/
|
||||
fun createTokenMetrics(): TokenMetrics
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents an update during text generation
|
||||
*/
|
||||
data class GenerationUpdate(
|
||||
val text: String,
|
||||
val isComplete: Boolean
|
||||
)
|
||||
|
||||
/**
|
||||
* Metrics for token generation performance
|
||||
*/
|
||||
data class TokenMetrics(
|
||||
val tokensCount: Int,
|
||||
val ttftMs: Long,
|
||||
val tpsMs: Float,
|
||||
) {
|
||||
val text: String
|
||||
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal implementation of the above [InferenceService]s
|
||||
*/
|
||||
@Singleton
|
||||
internal class InferenceServiceImpl @Inject internal constructor(
|
||||
private val inferenceEngine: InferenceEngine
|
||||
) : ModelLoadingService, BenchmarkService, ConversationService {
|
||||
|
||||
/* InferenceService implementation */
|
||||
|
||||
override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
|
||||
|
||||
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
|
||||
override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
|
||||
|
||||
override fun setCurrentModel(model: ModelInfo) {
|
||||
_currentModel.value = model
|
||||
}
|
||||
|
||||
override suspend fun unloadModel() = inferenceEngine.unloadModel()
|
||||
|
||||
/**
|
||||
* Shut down inference engine
|
||||
*/
|
||||
fun destroy() = inferenceEngine.destroy()
|
||||
|
||||
|
||||
/* ModelLoadingService implementation */
|
||||
|
||||
override suspend fun loadModelForBenchmark(): Boolean {
|
||||
return _currentModel.value?.let { model ->
|
||||
try {
|
||||
inferenceEngine.loadModel(model.path)
|
||||
|
|
@ -53,13 +136,10 @@ class InferenceManager @Inject constructor(
|
|||
Log.e("InferenceManager", "Error loading model", e)
|
||||
false
|
||||
}
|
||||
} ?: false
|
||||
} == true
|
||||
}
|
||||
|
||||
/**
|
||||
* Load a model for conversation
|
||||
*/
|
||||
suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
|
||||
override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
|
||||
_systemPrompt.value = systemPrompt
|
||||
return _currentModel.value?.let { model ->
|
||||
try {
|
||||
|
|
@ -69,23 +149,30 @@ class InferenceManager @Inject constructor(
|
|||
Log.e("InferenceManager", "Error loading model", e)
|
||||
false
|
||||
}
|
||||
} ?: false
|
||||
} == true
|
||||
}
|
||||
|
||||
/**
|
||||
* Run benchmark
|
||||
*/
|
||||
suspend fun benchmark(
|
||||
pp: Int = 512,
|
||||
tg: Int = 128,
|
||||
pl: Int = 1,
|
||||
nr: Int = 3
|
||||
): String = inferenceEngine.bench(pp, tg, pl, nr)
|
||||
|
||||
/**
|
||||
* Generate response from prompt
|
||||
*/
|
||||
fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
|
||||
/* BenchmarkService implementation */
|
||||
|
||||
override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
|
||||
inferenceEngine.bench(pp, tg, pl, nr)
|
||||
|
||||
override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
|
||||
|
||||
|
||||
/* ConversationService implementation */
|
||||
|
||||
private val _systemPrompt = MutableStateFlow<String?>(null)
|
||||
override val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
|
||||
|
||||
// Token metrics tracking
|
||||
private var generationStartTime: Long = 0L
|
||||
private var firstTokenTime: Long = 0L
|
||||
private var tokenCount: Int = 0
|
||||
private var isFirstToken: Boolean = true
|
||||
|
||||
override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
|
||||
try {
|
||||
// Reset metrics tracking
|
||||
generationStartTime = System.currentTimeMillis()
|
||||
|
|
@ -111,14 +198,14 @@ class InferenceManager @Inject constructor(
|
|||
response.append(token)
|
||||
|
||||
// Emit ongoing response (not completed)
|
||||
emit(Pair(response.toString(), false))
|
||||
emit(GenerationUpdate(response.toString(), false))
|
||||
}
|
||||
|
||||
// Calculate final metrics after completion
|
||||
val metrics = createTokenMetrics()
|
||||
|
||||
// Emit final response with completion flag
|
||||
emit(Pair(response.toString(), true))
|
||||
emit(GenerationUpdate(response.toString(), true))
|
||||
} catch (e: Exception) {
|
||||
// Emit error
|
||||
val metrics = createTokenMetrics()
|
||||
|
|
@ -126,10 +213,7 @@ class InferenceManager @Inject constructor(
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create token metrics based on current state
|
||||
*/
|
||||
fun createTokenMetrics(): TokenMetrics {
|
||||
override fun createTokenMetrics(): TokenMetrics {
|
||||
val endTime = System.currentTimeMillis()
|
||||
val totalTimeMs = endTime - generationStartTime
|
||||
|
||||
|
|
@ -147,23 +231,4 @@ class InferenceManager @Inject constructor(
|
|||
if (tokens <= 0 || timeMs <= 0) return 0f
|
||||
return (tokens.toFloat() * 1000f) / timeMs
|
||||
}
|
||||
|
||||
/**
|
||||
* Unload current model
|
||||
*/
|
||||
suspend fun unloadModel() = inferenceEngine.unloadModel()
|
||||
|
||||
/**
|
||||
* Cleanup resources
|
||||
*/
|
||||
fun destroy() = inferenceEngine.destroy()
|
||||
}
|
||||
|
||||
data class TokenMetrics(
|
||||
val tokensCount: Int,
|
||||
val ttftMs: Long,
|
||||
val tpsMs: Float,
|
||||
) {
|
||||
val text: String
|
||||
get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
|
||||
}
|
||||
|
|
@ -3,8 +3,8 @@ package com.example.llama.revamp.viewmodel
|
|||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.engine.BenchmarkService
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.engine.InferenceManager
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.launch
|
||||
|
|
@ -12,18 +12,18 @@ import javax.inject.Inject
|
|||
|
||||
@HiltViewModel
|
||||
class BenchmarkViewModel @Inject constructor(
|
||||
private val inferenceManager: InferenceManager
|
||||
private val benchmarkService: BenchmarkService
|
||||
) : ViewModel() {
|
||||
|
||||
val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
|
||||
val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
|
||||
val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
|
||||
val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
|
||||
val benchmarkResults: StateFlow<String?> = benchmarkService.results
|
||||
val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
|
||||
|
||||
/**
|
||||
* Run benchmark with specified parameters
|
||||
*/
|
||||
fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
|
||||
viewModelScope.launch {
|
||||
inferenceManager.benchmark(pp, tg, pl, nr)
|
||||
benchmarkService.benchmark(pp, tg, pl, nr)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
|
|||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.engine.InferenceManager
|
||||
import com.example.llama.revamp.engine.ConversationService
|
||||
import com.example.llama.revamp.engine.TokenMetrics
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.Job
|
||||
|
|
@ -14,17 +14,16 @@ import java.text.SimpleDateFormat
|
|||
import java.util.Date
|
||||
import java.util.Locale
|
||||
import javax.inject.Inject
|
||||
import kotlin.getValue
|
||||
|
||||
|
||||
@HiltViewModel
|
||||
class ConversationViewModel @Inject constructor(
|
||||
private val inferenceManager: InferenceManager
|
||||
private val conversationService: ConversationService
|
||||
) : ViewModel() {
|
||||
|
||||
val engineState = inferenceManager.engineState
|
||||
val selectedModel = inferenceManager.currentModel
|
||||
val systemPrompt = inferenceManager.systemPrompt
|
||||
val engineState = conversationService.engineState
|
||||
val selectedModel = conversationService.currentSelectedModel
|
||||
val systemPrompt = conversationService.systemPrompt
|
||||
|
||||
// Messages in conversation
|
||||
private val _messages = MutableStateFlow<List<Message>>(emptyList())
|
||||
|
|
@ -60,7 +59,7 @@ class ConversationViewModel @Inject constructor(
|
|||
// Collect response
|
||||
tokenCollectionJob = viewModelScope.launch {
|
||||
try {
|
||||
inferenceManager.generateResponse(content)
|
||||
conversationService.generateResponse(content)
|
||||
.collect { (text, isComplete) ->
|
||||
updateAssistantMessage(text, isComplete)
|
||||
}
|
||||
|
|
@ -85,7 +84,7 @@ class ConversationViewModel @Inject constructor(
|
|||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
||||
content = text,
|
||||
timestamp = currentAssistantMessage.timestamp,
|
||||
metrics = inferenceManager.createTokenMetrics()
|
||||
metrics = conversationService.createTokenMetrics()
|
||||
)
|
||||
} else {
|
||||
// Ongoing message update
|
||||
|
|
@ -110,7 +109,7 @@ class ConversationViewModel @Inject constructor(
|
|||
currentMessages[lastIndex] = Message.Assistant.Completed(
|
||||
content = "${currentAssistantMessage.content}[Error: ${e.message}]",
|
||||
timestamp = currentAssistantMessage.timestamp,
|
||||
metrics = inferenceManager.createTokenMetrics()
|
||||
metrics = conversationService.createTokenMetrics()
|
||||
)
|
||||
_messages.value = currentMessages
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
|
|||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import com.example.llama.revamp.engine.InferenceEngine
|
||||
import com.example.llama.revamp.engine.InferenceManager
|
||||
import com.example.llama.revamp.engine.InferenceService
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import javax.inject.Inject
|
||||
|
||||
|
|
@ -11,20 +11,14 @@ import javax.inject.Inject
|
|||
* Main ViewModel that expose the core states of [InferenceEngine]
|
||||
*/
|
||||
class MainViewModel @Inject constructor (
|
||||
private val inferenceManager: InferenceManager,
|
||||
private val inferenceService: InferenceService,
|
||||
) : ViewModel() {
|
||||
|
||||
val engineState = inferenceManager.engineState
|
||||
val engineState = inferenceService.engineState
|
||||
|
||||
/**
|
||||
* Unload the current model and release the resources
|
||||
*/
|
||||
suspend fun unloadModel() = inferenceManager.unloadModel()
|
||||
|
||||
companion object {
|
||||
private val TAG = MainViewModel::class.java.simpleName
|
||||
|
||||
private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
|
||||
}
|
||||
suspend fun unloadModel() = inferenceService.unloadModel()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.data.model.SystemPrompt
|
||||
import com.example.llama.revamp.data.repository.SystemPromptRepository
|
||||
import com.example.llama.revamp.engine.InferenceManager
|
||||
import com.example.llama.revamp.engine.ModelLoadingService
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.flow.SharingStarted
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
|
|
@ -14,14 +14,14 @@ import javax.inject.Inject
|
|||
|
||||
@HiltViewModel
|
||||
class ModelLoadingViewModel @Inject constructor(
|
||||
private val inferenceManager: InferenceManager,
|
||||
private val modelLoadingService: ModelLoadingService,
|
||||
private val repository: SystemPromptRepository
|
||||
) : ViewModel() {
|
||||
|
||||
/**
|
||||
* Currently selected model to be loaded
|
||||
*/
|
||||
val selectedModel = inferenceManager.currentModel
|
||||
val selectedModel = modelLoadingService.currentSelectedModel
|
||||
|
||||
/**
|
||||
* Preset prompts
|
||||
|
|
@ -83,13 +83,13 @@ class ModelLoadingViewModel @Inject constructor(
|
|||
* Prepares the engine for benchmark mode.
|
||||
*/
|
||||
suspend fun prepareForBenchmark() =
|
||||
inferenceManager.loadModelForBenchmark()
|
||||
modelLoadingService.loadModelForBenchmark()
|
||||
|
||||
/**
|
||||
* Prepare for conversation
|
||||
*/
|
||||
suspend fun prepareForConversation(systemPrompt: String? = null) =
|
||||
inferenceManager.loadModelForConversation(systemPrompt)
|
||||
modelLoadingService.loadModelForConversation(systemPrompt)
|
||||
|
||||
|
||||
companion object {
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
|
|||
import androidx.lifecycle.viewModelScope
|
||||
import com.example.llama.revamp.data.model.ModelInfo
|
||||
import com.example.llama.revamp.data.repository.ModelRepository
|
||||
import com.example.llama.revamp.engine.InferenceManager
|
||||
import com.example.llama.revamp.engine.InferenceService
|
||||
import dagger.hilt.android.lifecycle.HiltViewModel
|
||||
import kotlinx.coroutines.flow.SharingStarted
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
|
|
@ -15,7 +15,7 @@ import javax.inject.Inject
|
|||
|
||||
@HiltViewModel
|
||||
class ModelSelectionViewModel @Inject constructor(
|
||||
private val inferenceManager: InferenceManager,
|
||||
private val inferenceService: InferenceService,
|
||||
private val modelRepository: ModelRepository
|
||||
) : ViewModel() {
|
||||
|
||||
|
|
@ -29,27 +29,17 @@ class ModelSelectionViewModel @Inject constructor(
|
|||
initialValue = emptyList()
|
||||
)
|
||||
|
||||
/**
|
||||
* Access to currently selected model
|
||||
*/
|
||||
val selectedModel = inferenceManager.currentModel
|
||||
|
||||
/**
|
||||
* Select a model and update its last used timestamp
|
||||
*/
|
||||
fun selectModel(modelInfo: ModelInfo) {
|
||||
inferenceManager.setCurrentModel(modelInfo)
|
||||
inferenceService.setCurrentModel(modelInfo)
|
||||
|
||||
viewModelScope.launch {
|
||||
modelRepository.updateModelLastUsed(modelInfo.id)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unload model when navigating away
|
||||
*/
|
||||
suspend fun unloadModel() = inferenceManager.unloadModel()
|
||||
|
||||
companion object {
|
||||
private val TAG = ModelSelectionViewModel::class.java.simpleName
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue