core: throw Exception instead of returning null if model fails to load

Han Yin 2025-04-18 16:44:06 -07:00
parent f313362ced
commit 8a682ff85d
1 changed file with 33 additions and 19 deletions

@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
 
     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics
 }
 
 interface BenchmarkService : InferenceService {
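
With the nullable return type gone, callers can no longer treat null as the failure signal and must handle exceptions instead. A minimal caller-side sketch of the before/after contract, using hypothetical stand-ins for the real types (the loadTimeMs field and the runBenchmark caller are illustrative, not from this repo):

// Hypothetical stand-ins for the real types, for illustration only.
data class ModelLoadingMetrics(val loadTimeMs: Long)

interface ModelLoadingService {
    suspend fun loadModelForBenchmark(): ModelLoadingMetrics
}

// Before this change, a failed load surfaced as null:
//     val metrics = service.loadModelForBenchmark() ?: return
// After it, failure propagates as an exception the caller must catch.
suspend fun runBenchmark(service: ModelLoadingService) {
    val metrics = try {
        service.loadModelForBenchmark()
    } catch (e: IllegalStateException) {
        // Thrown when no model is selected (see the implementation hunks below).
        println("Load failed: ${e.message}")
        return
    }
    println("Model loaded in ${metrics.loadTimeMs} ms")
}

Engine-level load errors are rethrown as-is, so a caller that also wants to survive those would catch Exception more broadly.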
@@ -119,7 +119,11 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val inferenceEngine: InferenceEngine
 ) : ModelLoadingService, BenchmarkService, ConversationService {
 
-    /* InferenceService implementation */
+    /*
+     *
+     * InferenceService implementation
+     *
+     */
 
     override val engineState: StateFlow<State> = inferenceEngine.state
@@ -135,27 +139,30 @@ internal class InferenceServiceImpl @Inject internal constructor(
      */
     fun destroy() = inferenceEngine.destroy()
 
-    /* ModelLoadingService implementation */
-
-    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
-        return _currentModel.value?.let { model ->
+    /*
+     *
+     * ModelLoadingService implementation
+     *
+     */
+
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics =
+        _currentModel.value?.let { model ->
             try {
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
                 val modelLoadEndTs = System.currentTimeMillis()
                 ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
-                Log.e("InferenceManager", "Error loading model", e)
-                null
+                Log.e(TAG, "Error loading model", e)
+                throw e
             }
-        }
-    }
+        } ?: throw IllegalStateException("No model selected!")
 
-    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
-        _systemPrompt.value = systemPrompt
-        return _currentModel.value?.let { model ->
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics =
+        _currentModel.value?.let { model ->
             try {
+                _systemPrompt.value = systemPrompt
                 val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
                 val modelLoadEndTs = System.currentTimeMillis()
@@ -173,14 +180,17 @@ internal class InferenceServiceImpl @Inject internal constructor(
                 )
             }
             } catch (e: Exception) {
-                Log.e("InferenceManager", "Error loading model", e)
-                null
+                Log.e(TAG, "Error loading model", e)
+                throw e
             }
-        }
-    }
+        } ?: throw IllegalStateException("No model selected!")
 
-    /* BenchmarkService implementation */
+    /*
+     *
+     * BenchmarkService implementation
+     *
+     */
 
     override suspend fun benchmark(pp: Int, tg: Int, pl: Int, nr: Int): String =
         inferenceEngine.bench(pp, tg, pl, nr).also {
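
Both load functions now share the same expression-body shape: _currentModel.value?.let { ... } ?: throw. A self-contained sketch of the idiom, with hypothetical stand-ins (currentModel, load) for the real members:

data class Model(val path: String)

// Hypothetical stand-in for _currentModel.value.
var currentModel: Model? = null

// Pretend load that returns a load time in ms.
fun load(model: Model): Long = 42L

// Non-null receiver: run the block and return its result.
// Null receiver: fall through to the elvis branch and throw.
fun loadOrThrow(): Long =
    currentModel?.let { model -> load(model) }
        ?: throw IllegalStateException("No model selected!")

fun main() {
    try {
        loadOrThrow()
    } catch (e: IllegalStateException) {
        println(e.message)  // No model selected!
    }
    currentModel = Model("/models/example.bin")
    println(loadOrThrow())  // 42
}

One caveat of this shape: the elvis branch also fires if the let body itself evaluates to null. That is safe here only because each body either returns a non-null ModelLoadingMetrics or throws.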
@@ -264,4 +274,8 @@ internal class InferenceServiceImpl @Inject internal constructor(
         if (tokens <= 0 || timeMs <= 0) return 0f
         return (tokens.toFloat() * 1000f) / timeMs
     }
+
+    companion object {
+        private val TAG = InferenceServiceImpl::class.java.simpleName
+    }
 }
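
The new TAG companion also retires the stale hard-coded "InferenceManager" tag in favor of one derived from the class name. A minimal sketch of the idiom, independent of this repo:

import android.util.Log

class Worker {
    fun doWork() {
        Log.d(TAG, "working")  // logs under the tag "Worker"
    }

    companion object {
        // Derived from the class name, so renames keep the tag in sync.
        // Log.isLoggable rejects tags longer than 23 chars on API 25 and
        // below; "InferenceServiceImpl" (20 chars) stays under that limit.
        private val TAG = Worker::class.java.simpleName
    }
}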