diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
index 87e72ea1fd..7d0f791a57 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/InferenceServices.kt
@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
 
     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
 }
 
 interface BenchmarkService : InferenceService {
@@ -119,7 +119,11 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val inferenceEngine: InferenceEngine
 ) : ModelLoadingService, BenchmarkService, ConversationService {
 
-    /* InferenceService implementation */
+    /*
+     *
+     * InferenceService implementation
+     *
+     */
 
     override val engineState: StateFlow = inferenceEngine.state
@@ -135,27 +139,30 @@ internal class InferenceServiceImpl @Inject internal constructor(
      */
     fun destroy() = inferenceEngine.destroy()
 
+    /*
+     *
+     * ModelLoadingService implementation
+     *
+     */
 
-    /* ModelLoadingService implementation */
-
-    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
-        return _currentModel.value?.let { model ->
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
+        _currentModel.value?.let { model ->
             try {
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
                 val modelLoadEndTs = System.currentTimeMillis()
                 ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
-                Log.e("InferenceManager", "Error loading model", e)
-                null
+                Log.e(TAG, "Error loading model", e)
+                throw e
             }
-        }
-    }
+        } ?: throw IllegalStateException("No model selected!")
 
-    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
-        _systemPrompt.value = systemPrompt
-        return _currentModel.value?.let { model ->
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
+        _currentModel.value?.let { model ->
             try {
+                _systemPrompt.value = systemPrompt
+
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
                 val modelLoadEndTs = System.currentTimeMillis()
@@ -173,14 +180,17 @@ internal class InferenceServiceImpl @Inject internal constructor(
                     )
                 }
             } catch (e: Exception) {
-                Log.e("InferenceManager", "Error loading model", e)
-                null
+                Log.e(TAG, "Error loading model", e)
+                throw e
             }
-        }
-    }
+        } ?: throw IllegalStateException("No model selected!")
 
-    /* BenchmarkService implementation */
+    /*
+     *
+     * BenchmarkService implementation
+     *
+     */
 
     override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
         inferenceEngine.bench(pp, tg, pl, nr).also {
@@ -264,4 +274,8 @@ internal class InferenceServiceImpl @Inject internal constructor(
         if (tokens <= 0 || timeMs <= 0) return 0f
         return (tokens.toFloat() * 1000f) / timeMs
     }
+
+    companion object {
+        private val TAG = InferenceServiceImpl::class.java.simpleName
+    }
 }
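
Note on the API change: loadModelForBenchmark() and loadModelForConversation() now return a
non-null ModelLoadingMetrics and report failure by throwing (load errors are logged and
re-thrown; a missing model selection raises IllegalStateException), so callers that previously
branched on a null result must catch instead. A minimal caller-side sketch, assuming only the
interfaces in this file; BenchmarkViewModel, prepareBenchmark(), and onNoModelSelected() are
illustrative names, not part of this diff:

    import android.util.Log

    // Hypothetical caller; only ModelLoadingService and ModelLoadingMetrics come from the
    // diff above, everything else is an illustrative stand-in.
    class BenchmarkViewModel(private val modelLoading: ModelLoadingService) {

        suspend fun prepareBenchmark() {
            val metrics = try {
                modelLoading.loadModelForBenchmark()
            } catch (e: IllegalStateException) {
                // The service now throws this when no model is selected.
                onNoModelSelected()
                return
            } catch (e: Exception) {
                // Load failures are re-thrown by the service after being logged there.
                Log.e("BenchmarkViewModel", "Model load failed", e)
                return
            }
            // metrics is guaranteed non-null on the happy path.
            Log.d("BenchmarkViewModel", "Model loaded: $metrics")
        }

        private fun onNoModelSelected() {
            // e.g. route back to model selection (illustrative stub).
        }
    }

The trade-off of the new contract: the happy path stays non-null and failures can no longer be
silently ignored, but every call site needs a try/catch.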