core: break down InferenceManager due to the Interface Segregation Principle
This commit is contained in:
parent 286ed05f13
commit 9cfa74f754
@@ -6,7 +6,12 @@ import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
 import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
+import com.example.llama.revamp.engine.BenchmarkService
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceService
+import com.example.llama.revamp.engine.InferenceServiceImpl
+import com.example.llama.revamp.engine.ModelLoadingService
 import com.example.llama.revamp.monitoring.PerformanceMonitor
 import dagger.Binds
 import dagger.Module
@@ -18,7 +23,19 @@ import javax.inject.Singleton
 @Module
 @InstallIn(SingletonComponent::class)
-abstract class AppModule {
+internal abstract class AppModule {
+
+    @Binds
+    abstract fun bindInferenceService(impl: InferenceServiceImpl) : InferenceService
+
+    @Binds
+    abstract fun bindModelLoadingService(impl: InferenceServiceImpl) : ModelLoadingService
+
+    @Binds
+    abstract fun bindBenchmarkService(impl: InferenceServiceImpl) : BenchmarkService
+
+    @Binds
+    abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
 
     @Binds
     abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
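Because InferenceServiceImpl is a @Singleton (see the engine hunk below), the four @Binds declarations above all resolve to one shared instance; each ViewModel just receives a narrower view of it. A minimal, Hilt-free sketch of that resolution, with simplified stand-in types rather than the project's real ones:

// Sketch only: mimics what @Binds + @Singleton give you at injection time.
interface Loader { fun load(): Boolean }
interface Bencher { fun bench(): String }

// One implementation backs both narrow interfaces, like InferenceServiceImpl.
class EngineFacade : Loader, Bencher {
    override fun load() = true
    override fun bench() = "ok"
}

fun main() {
    val shared = EngineFacade()      // the single @Singleton instance
    val loader: Loader = shared      // what bindModelLoadingService hands out
    val bencher: Bencher = shared    // what bindBenchmarkService hands out
    println(loader === bencher)      // true: same state behind both views
}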
@@ -10,41 +10,124 @@ import kotlinx.coroutines.flow.flow
 import javax.inject.Inject
 import javax.inject.Singleton
 
-@Singleton
-class InferenceManager @Inject constructor(
-    private val inferenceEngine: InferenceEngine
-) {
-    // Expose engine state
-    val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
-
-    // Benchmark results
-    val benchmarkResults: StateFlow<String?> = inferenceEngine.benchmarkResults
-
-    // Currently loaded model
-    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
-    val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
-
-    // System prompt
-    private val _systemPrompt = MutableStateFlow<String?>(null)
-    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
-
-    // Token metrics tracking
-    private var generationStartTime: Long = 0L
-    private var firstTokenTime: Long = 0L
-    private var tokenCount: Int = 0
-    private var isFirstToken: Boolean = true
-
-    /**
-     * Set current model
-     */
-    fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
-
-    /**
-     * Load a model for benchmark
-     */
-    suspend fun loadModelForBenchmark(): Boolean {
+interface InferenceService {
+    /**
+     * Expose engine state
+     */
+    val engineState: StateFlow<InferenceEngine.State>
+
+    /**
+     * Currently selected model
+     */
+    val currentSelectedModel: StateFlow<ModelInfo?>
+
+    /**
+     * Set current model
+     */
+    fun setCurrentModel(model: ModelInfo)
+
+    /**
+     * Unload current model and free resources
+     */
+    suspend fun unloadModel()
+}
+
+interface ModelLoadingService : InferenceService {
+    /**
+     * Load a model for benchmark
+     */
+    suspend fun loadModelForBenchmark(): Boolean
+
+    /**
+     * Load a model for conversation
+     */
+    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+}
+
+interface BenchmarkService : InferenceService {
+    /**
+     * Run benchmark
+     *
+     * @param pp: Prompt Processing size
+     * @param tg: Token Generation size
+     * @param pl: Parallel sequences
+     * @param nr: repetitions (Number of Runs)
+     */
+    suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+    /**
+     * Benchmark results
+     */
+    val results: StateFlow<String?>
+}
+
+interface ConversationService : InferenceService {
+    /**
+     * System prompt
+     */
+    val systemPrompt: StateFlow<String?>
+
+    /**
+     * Generate response from prompt
+     */
+    fun generateResponse(prompt: String): Flow<GenerationUpdate>
+
+    /**
+     * Create token metrics based on current state
+     */
+    fun createTokenMetrics(): TokenMetrics
+}
+
+/**
+ * Represents an update during text generation
+ */
+data class GenerationUpdate(
+    val text: String,
+    val isComplete: Boolean
+)
+
+/**
+ * Metrics for token generation performance
+ */
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
+}
+
+/**
+ * Internal implementation of the above [InferenceService]s
+ */
+@Singleton
+internal class InferenceServiceImpl @Inject internal constructor(
+    private val inferenceEngine: InferenceEngine
+) : ModelLoadingService, BenchmarkService, ConversationService {
+
+    /* InferenceService implementation */
+
+    override val engineState: StateFlow<InferenceEngine.State> = inferenceEngine.state
+
+    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
+    override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
+
+    override fun setCurrentModel(model: ModelInfo) {
+        _currentModel.value = model
+    }
+
+    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+
+    /**
+     * Shut down inference engine
+     */
+    fun destroy() = inferenceEngine.destroy()
+
+
+    /* ModelLoadingService implementation */
+
+    override suspend fun loadModelForBenchmark(): Boolean {
         return _currentModel.value?.let { model ->
             try {
                 inferenceEngine.loadModel(model.path)
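The two data classes introduced above are plain value holders, so their behavior is easy to check in isolation; for instance, TokenMetrics.text renders the summary string shown in the UI. A quick self-contained check, with invented sample values:

data class TokenMetrics(
    val tokensCount: Int,
    val ttftMs: Long,
    val tpsMs: Float,
) {
    val text: String
        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
}

fun main() {
    // Sample values for illustration only.
    println(TokenMetrics(tokensCount = 42, ttftMs = 180, tpsMs = 15.34f).text)
    // -> Tokens: 42, TTFT: 180ms, TPS: 15.3
}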
@@ -53,13 +136,10 @@ class InferenceManager @Inject constructor(
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Load a model for conversation
-     */
-    suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
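The `} ?: false` to `} == true` change in this hunk is behavior-preserving: the let-block yields null only when no model is selected, and both forms collapse that null to false. `== true` simply reads as a single predicate on the nullable result. A tiny sketch of the equivalence:

fun main() {
    val missing: Boolean? = null     // what the let-block yields with no model selected
    println(missing ?: false)        // false
    println(missing == true)         // false - same outcome, one comparison
    val loaded: Boolean? = true
    println(loaded == true)          // true when the block actually succeeded
}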
@@ -69,23 +149,30 @@ class InferenceManager @Inject constructor(
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Run benchmark
-     */
-    suspend fun benchmark(
-        pp: Int = 512,
-        tg: Int = 128,
-        pl: Int = 1,
-        nr: Int = 3
-    ): String = inferenceEngine.bench(pp, tg, pl, nr)
-
-    /**
-     * Generate response from prompt
-     */
-    fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
+    /* BenchmarkService implementation */
+
+    override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        inferenceEngine.bench(pp, tg, pl, nr)
+
+    override val results: StateFlow<String?> = inferenceEngine.benchmarkResults
+
+
+    /* ConversationService implementation */
+
+    private val _systemPrompt = MutableStateFlow<String?>(null)
+    override val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
+
+    // Token metrics tracking
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
+    override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
         try {
             // Reset metrics tracking
             generationStartTime = System.currentTimeMillis()
@@ -111,14 +198,14 @@ class InferenceManager @Inject constructor(
                 response.append(token)
 
                 // Emit ongoing response (not completed)
-                emit(Pair(response.toString(), false))
+                emit(GenerationUpdate(response.toString(), false))
             }
 
             // Calculate final metrics after completion
             val metrics = createTokenMetrics()
 
             // Emit final response with completion flag
-            emit(Pair(response.toString(), true))
+            emit(GenerationUpdate(response.toString(), true))
         } catch (e: Exception) {
             // Emit error
             val metrics = createTokenMetrics()
@@ -126,10 +213,7 @@ class InferenceManager @Inject constructor(
         }
     }
 
-    /**
-     * Create token metrics based on current state
-     */
-    fun createTokenMetrics(): TokenMetrics {
+    override fun createTokenMetrics(): TokenMetrics {
         val endTime = System.currentTimeMillis()
         val totalTimeMs = endTime - generationStartTime
@@ -147,23 +231,4 @@ class InferenceManager @Inject constructor(
         if (tokens <= 0 || timeMs <= 0) return 0f
         return (tokens.toFloat() * 1000f) / timeMs
     }
-
-    /**
-     * Unload current model
-     */
-    suspend fun unloadModel() = inferenceEngine.unloadModel()
-
-    /**
-     * Cleanup resources
-     */
-    fun destroy() = inferenceEngine.destroy()
-}
-
-data class TokenMetrics(
-    val tokensCount: Int,
-    val ttftMs: Long,
-    val tpsMs: Float,
-) {
-    val text: String
-        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
 }
@@ -3,8 +3,8 @@ package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -12,18 +12,18 @@ import javax.inject.Inject
 
 @HiltViewModel
 class BenchmarkViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val benchmarkService: BenchmarkService
 ) : ViewModel() {
 
-    val engineState: StateFlow<InferenceEngine.State> = inferenceManager.engineState
-    val benchmarkResults: StateFlow<String?> = inferenceManager.benchmarkResults
-    val selectedModel: StateFlow<ModelInfo?> = inferenceManager.currentModel
+    val engineState: StateFlow<InferenceEngine.State> = benchmarkService.engineState
+    val benchmarkResults: StateFlow<String?> = benchmarkService.results
+    val selectedModel: StateFlow<ModelInfo?> = benchmarkService.currentSelectedModel
 
     /**
     * Run benchmark with specified parameters
     */
     fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
         viewModelScope.launch {
-            inferenceManager.benchmark(pp, tg, pl, nr)
+            benchmarkService.benchmark(pp, tg, pl, nr)
         }
 }
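Note that the benchmark defaults (pp=512, tg=128, pl=1, nr=3) now live only in runBenchmark: the old InferenceManager.benchmark carried its own defaults, while the new BenchmarkService.benchmark takes explicit arguments. A condensed sketch of that split, using simplified stand-ins rather than the real classes:

// Simplified stand-ins for BenchmarkService and BenchmarkViewModel.
interface Bench {
    suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
}

class Vm(private val bench: Bench) {
    // Defaults sit at the call boundary; the interface method stays explicit.
    suspend fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
        bench.benchmark(pp, tg, pl, nr)
}

suspend fun main() {
    val vm = Vm(object : Bench {
        override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int) =
            "pp=$pp tg=$tg pl=$pl nr=$nr"
    })
    println(vm.runBenchmark())        // pp=512 tg=128 pl=1 nr=3
    println(vm.runBenchmark(pp = 64)) // pp=64 tg=128 pl=1 nr=3
}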
@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.TokenMetrics
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.Job
@@ -14,17 +14,16 @@ import java.text.SimpleDateFormat
 import java.util.Date
 import java.util.Locale
 import javax.inject.Inject
-import kotlin.getValue
 
 
 @HiltViewModel
 class ConversationViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val conversationService: ConversationService
 ) : ViewModel() {
 
-    val engineState = inferenceManager.engineState
-    val selectedModel = inferenceManager.currentModel
-    val systemPrompt = inferenceManager.systemPrompt
+    val engineState = conversationService.engineState
+    val selectedModel = conversationService.currentSelectedModel
+    val systemPrompt = conversationService.systemPrompt
 
     // Messages in conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
@@ -60,7 +59,7 @@ class ConversationViewModel @Inject constructor(
         // Collect response
         tokenCollectionJob = viewModelScope.launch {
             try {
-                inferenceManager.generateResponse(content)
+                conversationService.generateResponse(content)
                     .collect { (text, isComplete) ->
                         updateAssistantMessage(text, isComplete)
                     }
@@ -85,7 +84,7 @@ class ConversationViewModel @Inject constructor(
                 currentMessages[lastIndex] = Message.Assistant.Completed(
                     content = text,
                     timestamp = currentAssistantMessage.timestamp,
-                    metrics = inferenceManager.createTokenMetrics()
+                    metrics = conversationService.createTokenMetrics()
                 )
             } else {
                 // Ongoing message update
@@ -110,7 +109,7 @@ class ConversationViewModel @Inject constructor(
                 currentMessages[lastIndex] = Message.Assistant.Completed(
                     content = "${currentAssistantMessage.content}[Error: ${e.message}]",
                     timestamp = currentAssistantMessage.timestamp,
-                    metrics = inferenceManager.createTokenMetrics()
+                    metrics = conversationService.createTokenMetrics()
                 )
                 _messages.value = currentMessages
             }
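ConversationViewModel's collect block above still destructures with (text, isComplete) even though the flow now emits GenerationUpdate instead of Pair<String, Boolean>; that works because data classes generate component1/component2. A self-contained sketch of the same pattern:

import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.runBlocking

data class GenerationUpdate(val text: String, val isComplete: Boolean)

fun main() = runBlocking {
    flowOf(
        GenerationUpdate("Hel", isComplete = false),
        GenerationUpdate("Hello", isComplete = true),
    ).collect { (text, isComplete) ->   // componentN() comes from the data class
        println("$text (complete=$isComplete)")
    }
}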
@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject
@@ -11,20 +11,14 @@ import javax.inject.Inject
 * Main ViewModel that expose the core states of [InferenceEngine]
 */
class MainViewModel @Inject constructor (
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
 ) : ViewModel() {
 
-    val engineState = inferenceManager.engineState
+    val engineState = inferenceService.engineState
 
     /**
     * Unload the current model and release the resources
    */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
-    companion object {
-        private val TAG = MainViewModel::class.java.simpleName
-
-        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
-    }
+    suspend fun unloadModel() = inferenceService.unloadModel()
 }
@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.SystemPrompt
 import com.example.llama.revamp.data.repository.SystemPromptRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ModelLoadingService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -14,14 +14,14 @@ import javax.inject.Inject
 
 @HiltViewModel
 class ModelLoadingViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val modelLoadingService: ModelLoadingService,
     private val repository: SystemPromptRepository
 ) : ViewModel() {
 
     /**
     * Currently selected model to be loaded
    */
-    val selectedModel = inferenceManager.currentModel
+    val selectedModel = modelLoadingService.currentSelectedModel
 
     /**
     * Preset prompts
@@ -83,13 +83,13 @@ class ModelLoadingViewModel @Inject constructor(
     * Prepares the engine for benchmark mode.
    */
     suspend fun prepareForBenchmark() =
-        inferenceManager.loadModelForBenchmark()
+        modelLoadingService.loadModelForBenchmark()
 
     /**
     * Prepare for conversation
    */
     suspend fun prepareForConversation(systemPrompt: String? = null) =
-        inferenceManager.loadModelForConversation(systemPrompt)
+        modelLoadingService.loadModelForConversation(systemPrompt)
 
 
     companion object {
@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.data.repository.ModelRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -15,7 +15,7 @@ import javax.inject.Inject
 
 @HiltViewModel
 class ModelSelectionViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
     private val modelRepository: ModelRepository
 ) : ViewModel() {
@@ -29,27 +29,17 @@ class ModelSelectionViewModel @Inject constructor(
         initialValue = emptyList()
     )
 
-    /**
-     * Access to currently selected model
-     */
-    val selectedModel = inferenceManager.currentModel
-
     /**
     * Select a model and update its last used timestamp
    */
     fun selectModel(modelInfo: ModelInfo) {
-        inferenceManager.setCurrentModel(modelInfo)
+        inferenceService.setCurrentModel(modelInfo)
 
         viewModelScope.launch {
             modelRepository.updateModelLastUsed(modelInfo.id)
         }
     }
 
-    /**
-     * Unload model when navigating away
-     */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
     companion object {
         private val TAG = ModelSelectionViewModel::class.java.simpleName