diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
index 70b20e3cdf..fef6c26ea2 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
@@ -6,7 +6,12 @@ import com.example.llama.revamp.data.repository.ModelRepository
 import com.example.llama.revamp.data.repository.ModelRepositoryImpl
 import com.example.llama.revamp.data.repository.SystemPromptRepository
 import com.example.llama.revamp.data.repository.SystemPromptRepositoryImpl
+import com.example.llama.revamp.engine.BenchmarkService
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.InferenceEngine
+import com.example.llama.revamp.engine.InferenceService
+import com.example.llama.revamp.engine.InferenceServiceImpl
+import com.example.llama.revamp.engine.ModelLoadingService
 import com.example.llama.revamp.monitoring.PerformanceMonitor
 import dagger.Binds
 import dagger.Module
@@ -18,7 +23,19 @@ import javax.inject.Singleton
 
 @Module
 @InstallIn(SingletonComponent::class)
-abstract class AppModule {
+internal abstract class AppModule {
+
+    @Binds
+    abstract fun bindInferenceService(impl: InferenceServiceImpl) : InferenceService
+
+    @Binds
+    abstract fun bindModelLoadingService(impl: InferenceServiceImpl) : ModelLoadingService
+
+    @Binds
+    abstract fun bindBenchmarkService(impl: InferenceServiceImpl) : BenchmarkService
+
+    @Binds
+    abstract fun bindConversationService(impl: InferenceServiceImpl) : ConversationService
 
     @Binds
     abstract fun bindsModelsRepository(impl: ModelRepositoryImpl): ModelRepository
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
similarity index 57%
rename from examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt
rename to examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
index e0b36eec5e..3da3065224 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceManager.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
@@ -10,41 +10,124 @@ import kotlinx.coroutines.flow.flow
 import javax.inject.Inject
 import javax.inject.Singleton
 
-@Singleton
-class InferenceManager @Inject constructor(
-    private val inferenceEngine: InferenceEngine
-) {
-    // Expose engine state
-    val engineState: StateFlow = inferenceEngine.state
+interface InferenceService {
+    /**
+     * Expose engine state
+     */
+    val engineState: StateFlow
 
-    // Benchmark results
-    val benchmarkResults: StateFlow = inferenceEngine.benchmarkResults
-
-    // Currently loaded model
-    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
-    val currentModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
-
-    // System prompt
-    private val _systemPrompt = MutableStateFlow<String?>(null)
-    val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
-
-    // Token metrics tracking
-    private var generationStartTime: Long = 0L
-    private var firstTokenTime: Long = 0L
-    private var tokenCount: Int = 0
-    private var isFirstToken: Boolean = true
+    /**
+     * Currently selected model
+     */
+    val currentSelectedModel: StateFlow<ModelInfo?>
 
     /**
      * Set current model
     */
-    fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    fun setCurrentModel(model: ModelInfo)
 
+    /**
+     * Unload current model and free resources
+     */
+    suspend fun unloadModel()
+}
+
+interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
     */
-    suspend fun loadModelForBenchmark(): Boolean {
+    suspend fun loadModelForBenchmark(): Boolean
+
+    /**
+     * Load a model for conversation
+     */
+    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+}
+
+interface BenchmarkService : InferenceService {
+    /**
+     * Run benchmark
+     *
+     * @param pp: Prompt Processing size
+     * @param tg: Token Generation size
+     * @param pl: Parallel sequences
+     * @param nr: repetitions (Number of Runs)
+     */
+    suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+    /**
+     * Benchmark results
+     */
+    val results: StateFlow
+}
+
+interface ConversationService : InferenceService {
+    /**
+     * System prompt
+     */
+    val systemPrompt: StateFlow<String?>
+
+    /**
+     * Generate response from prompt
+     */
+    fun generateResponse(prompt: String): Flow<GenerationUpdate>
+
+    /**
+     * Create token metrics based on current state
+     */
+    fun createTokenMetrics(): TokenMetrics
+}
+
+/**
+ * Represents an update during text generation
+ */
+data class GenerationUpdate(
+    val text: String,
+    val isComplete: Boolean
+)
+
+/**
+ * Metrics for token generation performance
+ */
+data class TokenMetrics(
+    val tokensCount: Int,
+    val ttftMs: Long,
+    val tpsMs: Float,
+) {
+    val text: String
+        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
+}
+
+/**
+ * Internal implementation of the above [InferenceService]s
+ */
+@Singleton
+internal class InferenceServiceImpl @Inject internal constructor(
+    private val inferenceEngine: InferenceEngine
+) : ModelLoadingService, BenchmarkService, ConversationService {
+
+    /* InferenceService implementation */
+
+    override val engineState: StateFlow = inferenceEngine.state
+
+    private val _currentModel = MutableStateFlow<ModelInfo?>(null)
+    override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
+
+    override fun setCurrentModel(model: ModelInfo) {
+        _currentModel.value = model
+    }
+
+    override suspend fun unloadModel() = inferenceEngine.unloadModel()
+
+    /**
+     * Shut down inference engine
+     */
+    fun destroy() = inferenceEngine.destroy()
+
+
+    /* ModelLoadingService implementation */
+
+    override suspend fun loadModelForBenchmark(): Boolean {
         return _currentModel.value?.let { model ->
             try {
                 inferenceEngine.loadModel(model.path)
@@ -53,13 +136,10 @@ class InferenceManager @Inject constructor(
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Load a model for conversation
-     */
-    suspend fun loadModelForConversation(systemPrompt: String? = null): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
@@ -69,23 +149,30 @@ class InferenceManager @Inject constructor(
                 Log.e("InferenceManager", "Error loading model", e)
                 false
             }
-        } ?: false
+        } == true
     }
 
-    /**
-     * Run benchmark
-     */
-    suspend fun benchmark(
-        pp: Int = 512,
-        tg: Int = 128,
-        pl: Int = 1,
-        nr: Int = 3
-    ): String = inferenceEngine.bench(pp, tg, pl, nr)
-
-    /**
-     * Generate response from prompt
-     */
-    fun generateResponse(prompt: String): Flow<Pair<String, Boolean>> = flow {
+    /* BenchmarkService implementation */
+
+    override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
+        inferenceEngine.bench(pp, tg, pl, nr)
+
+    override val results: StateFlow = inferenceEngine.benchmarkResults
+
+
+    /* ConversationService implementation */
+
+    private val _systemPrompt = MutableStateFlow<String?>(null)
+    override val systemPrompt: StateFlow<String?> = _systemPrompt.asStateFlow()
+
+    // Token metrics tracking
+    private var generationStartTime: Long = 0L
+    private var firstTokenTime: Long = 0L
+    private var tokenCount: Int = 0
+    private var isFirstToken: Boolean = true
+
+    override fun generateResponse(prompt: String): Flow<GenerationUpdate> = flow {
         try {
             // Reset metrics tracking
             generationStartTime = System.currentTimeMillis()
@@ -111,14 +198,14 @@ class InferenceManager @Inject constructor(
                 response.append(token)
 
                 // Emit ongoing response (not completed)
-                emit(Pair(response.toString(), false))
+                emit(GenerationUpdate(response.toString(), false))
             }
 
             // Calculate final metrics after completion
             val metrics = createTokenMetrics()
 
             // Emit final response with completion flag
-            emit(Pair(response.toString(), true))
+            emit(GenerationUpdate(response.toString(), true))
         } catch (e: Exception) {
             // Emit error
             val metrics = createTokenMetrics()
@@ -126,10 +213,7 @@ class InferenceManager @Inject constructor(
         }
     }
 
-    /**
-     * Create token metrics based on current state
-     */
-    fun createTokenMetrics(): TokenMetrics {
+    override fun createTokenMetrics(): TokenMetrics {
         val endTime = System.currentTimeMillis()
         val totalTimeMs = endTime - generationStartTime
 
@@ -147,23 +231,4 @@ class InferenceManager @Inject constructor(
         if (tokens <= 0 || timeMs <= 0) return 0f
         return (tokens.toFloat() * 1000f) / timeMs
     }
-
-    /**
-     * Unload current model
-     */
-    suspend fun unloadModel() = inferenceEngine.unloadModel()
-
-    /**
-     * Cleanup resources
-     */
-    fun destroy() = inferenceEngine.destroy()
-}
-
-data class TokenMetrics(
-    val tokensCount: Int,
-    val ttftMs: Long,
-    val tpsMs: Float,
-) {
-    val text: String
-        get() = "Tokens: $tokensCount, TTFT: ${ttftMs}ms, TPS: ${"%.1f".format(tpsMs)}"
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
index 179d5eb159..d4427ad3e8 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/BenchmarkViewModel.kt
@@ -3,8 +3,8 @@ package com.example.llama.revamp.viewmodel
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
+import com.example.llama.revamp.engine.BenchmarkService
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.StateFlow
 import kotlinx.coroutines.launch
@@ -12,18 +12,18 @@ import javax.inject.Inject
 
 @HiltViewModel
 class BenchmarkViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val benchmarkService: BenchmarkService
 ) : ViewModel() {
 
-    val engineState: StateFlow = inferenceManager.engineState
-    val benchmarkResults: StateFlow = inferenceManager.benchmarkResults
-    val selectedModel: StateFlow = inferenceManager.currentModel
+    val engineState: StateFlow = benchmarkService.engineState
+    val benchmarkResults: StateFlow = benchmarkService.results
+    val selectedModel: StateFlow = benchmarkService.currentSelectedModel
 
     /**
      * Run benchmark with specified parameters
     */
    fun runBenchmark(pp: Int = 512, tg: Int = 128, pl: Int = 1, nr: Int = 3) =
        viewModelScope.launch {
-            inferenceManager.benchmark(pp, tg, pl, nr)
+            benchmarkService.benchmark(pp, tg, pl, nr)
        }
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt
index fd8e243f1d..c1598ada7d 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ConversationViewModel.kt
@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ConversationService
 import com.example.llama.revamp.engine.TokenMetrics
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.Job
@@ -14,17 +14,16 @@ import java.text.SimpleDateFormat
 import java.util.Date
 import java.util.Locale
 import javax.inject.Inject
-import kotlin.getValue
 
 @HiltViewModel
 class ConversationViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager
+    private val conversationService: ConversationService
 ) : ViewModel() {
 
-    val engineState = inferenceManager.engineState
-    val selectedModel = inferenceManager.currentModel
-    val systemPrompt = inferenceManager.systemPrompt
+    val engineState = conversationService.engineState
+    val selectedModel = conversationService.currentSelectedModel
+    val systemPrompt = conversationService.systemPrompt
 
     // Messages in conversation
     private val _messages = MutableStateFlow<List<Message>>(emptyList())
@@ -60,7 +59,7 @@ class ConversationViewModel @Inject constructor(
         // Collect response
         tokenCollectionJob = viewModelScope.launch {
             try {
-                inferenceManager.generateResponse(content)
+                conversationService.generateResponse(content)
                     .collect { (text, isComplete) ->
                         updateAssistantMessage(text, isComplete)
                     }
@@ -85,7 +84,7 @@ class ConversationViewModel @Inject constructor(
             currentMessages[lastIndex] = Message.Assistant.Completed(
                 content = text,
                 timestamp = currentAssistantMessage.timestamp,
-                metrics = inferenceManager.createTokenMetrics()
+                metrics = conversationService.createTokenMetrics()
             )
         } else {
             // Ongoing message update
@@ -110,7 +109,7 @@ class ConversationViewModel @Inject constructor(
             currentMessages[lastIndex] = Message.Assistant.Completed(
                 content = "${currentAssistantMessage.content}[Error: ${e.message}]",
                 timestamp = currentAssistantMessage.timestamp,
-                metrics = inferenceManager.createTokenMetrics()
+                metrics = conversationService.createTokenMetrics()
             )
             _messages.value = currentMessages
         }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
index ae709896d7..5ef1608a5a 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/MainViewModel.kt
@@ -2,7 +2,7 @@ package com.example.llama.revamp.viewmodel
 
 import androidx.lifecycle.ViewModel
 import com.example.llama.revamp.engine.InferenceEngine
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import javax.inject.Inject
 
@@ -11,20 +11,14 @@ import javax.inject.Inject
  * Main ViewModel that expose the core states of [InferenceEngine]
  */
 class MainViewModel @Inject constructor (
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
 ) : ViewModel() {
 
-    val engineState = inferenceManager.engineState
+    val engineState = inferenceService.engineState
 
     /**
      * Unload the current model and release the resources
     */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
-    companion object {
-        private val TAG = MainViewModel::class.java.simpleName
-
-        private const val SUBSCRIPTION_TIMEOUT_MS = 5000L
-    }
+    suspend fun unloadModel() = inferenceService.unloadModel()
 }
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt
index 8d4fe42ba7..46efcc6942 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelLoadingViewModel.kt
@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.SystemPrompt
 import com.example.llama.revamp.data.repository.SystemPromptRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.ModelLoadingService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -14,14 +14,14 @@ import javax.inject.Inject
 
 @HiltViewModel
 class ModelLoadingViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val modelLoadingService: ModelLoadingService,
     private val repository: SystemPromptRepository
 ) : ViewModel() {
 
     /**
      * Currently selected model to be loaded
     */
-    val selectedModel = inferenceManager.currentModel
+    val selectedModel = modelLoadingService.currentSelectedModel
 
     /**
      * Preset prompts
@@ -83,13 +83,13 @@ class ModelLoadingViewModel @Inject constructor(
     /**
      * Prepares the engine for benchmark mode.
     */
    suspend fun prepareForBenchmark() =
-        inferenceManager.loadModelForBenchmark()
+        modelLoadingService.loadModelForBenchmark()
 
     /**
      * Prepare for conversation
     */
    suspend fun prepareForConversation(systemPrompt: String? = null) =
-        inferenceManager.loadModelForConversation(systemPrompt)
+        modelLoadingService.loadModelForConversation(systemPrompt)
 
     companion object {
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
index f39d024b8c..dd68bfc63e 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/viewmodel/ModelSelectionViewModel.kt
@@ -4,7 +4,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.example.llama.revamp.data.model.ModelInfo
 import com.example.llama.revamp.data.repository.ModelRepository
-import com.example.llama.revamp.engine.InferenceManager
+import com.example.llama.revamp.engine.InferenceService
 import dagger.hilt.android.lifecycle.HiltViewModel
 import kotlinx.coroutines.flow.SharingStarted
 import kotlinx.coroutines.flow.StateFlow
@@ -15,7 +15,7 @@ import javax.inject.Inject
 
 @HiltViewModel
 class ModelSelectionViewModel @Inject constructor(
-    private val inferenceManager: InferenceManager,
+    private val inferenceService: InferenceService,
     private val modelRepository: ModelRepository
 ) : ViewModel() {
 
@@ -29,27 +29,17 @@ class ModelSelectionViewModel @Inject constructor(
         initialValue = emptyList()
     )
 
-    /**
-     * Access to currently selected model
-     */
-    val selectedModel = inferenceManager.currentModel
-
     /**
      * Select a model and update its last used timestamp
     */
    fun selectModel(modelInfo: ModelInfo) {
-        inferenceManager.setCurrentModel(modelInfo)
+        inferenceService.setCurrentModel(modelInfo)
 
         viewModelScope.launch {
            modelRepository.updateModelLastUsed(modelInfo.id)
        }
    }
 
-    /**
-     * Unload model when navigating away
-     */
-    suspend fun unloadModel() = inferenceManager.unloadModel()
-
     companion object {
        private val TAG = ModelSelectionViewModel::class.java.simpleName
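
Note on the pattern in this patch: `InferenceServiceImpl` is a single `@Singleton` exposed through several narrow interfaces via `@Binds`, so each ViewModel injects only the capability it needs (model loading, benchmarking, or conversation) while all of them observe the same underlying engine state. The sketch below is a minimal, framework-free illustration of that interface-segregation idea; the names `ModelSelection`, `Benchmarking`, and `EngineFacade` are illustrative only and not part of this change, and it assumes `kotlinx-coroutines-core` is on the classpath.

```kotlin
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow

// Narrow, role-based interfaces: callers depend only on what they use.
interface ModelSelection {
    val currentModel: StateFlow<String?>
    fun select(model: String)
}

interface Benchmarking {
    fun benchmark(pp: Int, tg: Int): String
}

// One shared implementation behind both roles, analogous to InferenceServiceImpl.
class EngineFacade : ModelSelection, Benchmarking {
    private val _currentModel = MutableStateFlow<String?>(null)
    override val currentModel: StateFlow<String?> = _currentModel.asStateFlow()

    override fun select(model: String) {
        _currentModel.value = model
    }

    override fun benchmark(pp: Int, tg: Int): String =
        "bench(${_currentModel.value}): pp=$pp tg=$tg"
}

fun main() {
    // In the app, Hilt's @Binds hands the same singleton to every injection site;
    // here the two interface views are wired by hand to show the shared state.
    val facade = EngineFacade()
    val selector: ModelSelection = facade
    val bench: Benchmarking = facade

    selector.select("some-model.gguf")
    println(bench.benchmark(pp = 512, tg = 128)) // sees the model selected above
}
```

One Dagger detail worth keeping in mind: the `@Binds` methods in `AppModule` are unscoped, which works here only because `InferenceServiceImpl` itself is `@Singleton`-scoped, so every interface binding resolves to the same instance. If the implementation ever lost that annotation, each ViewModel would silently receive its own copy and `currentSelectedModel` would no longer be shared across screens.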