lib: optimize engine loader; always perform a fresh detection when cache is null

Han Yin 2025-07-06 16:57:23 -07:00
parent e6413dd05d
commit 4ff924b273
2 changed files with 29 additions and 30 deletions

View File

@@ -13,7 +13,7 @@ interface TierDetection {
  * Higher tiers provide better performance on supported hardware.
  */
 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
-    NONE(404, "", "No valid Arm optimization available!"),
+    NONE(-404, "", "No valid Arm optimization available!"),
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
     T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
     T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),

View File

@@ -42,9 +42,7 @@ internal object InferenceEngineLoader {
      * Get the detected tier, loading from cache if needed
      */
     fun getDetectedTier(context: Context): LLamaTier? =
-        _detectedTier ?: runBlocking {
-            loadDetectedTierFromDataStore(context)
-        }
+        _detectedTier ?: runBlocking { obtainTier(context) }
 
     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
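
With this change a cache miss in getDetectedTier runs the full obtainTier path (storage read, then fresh detection) instead of returning null when the DataStore is empty, which is what the commit title promises. A hypothetical call site for illustration; names outside the loader (appContext, the log tag) are made up:

    // Illustrative caller: after the first call, the memoized _detectedTier
    // short-circuits the runBlocking fallback entirely.
    val tier = InferenceEngineLoader.getDetectedTier(appContext)
    if (tier != null && tier != LLamaTier.NONE) {
        Log.i("App", "Selected ${tier.libraryName}: ${tier.description}")
    }
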
@@ -57,11 +55,7 @@ internal object InferenceEngineLoader {
         return runBlocking {
             // Obtain the optimal tier from cache if available
-            val tier = loadDetectedTierFromDataStore(context) ?: run {
-                Log.i(TAG, "Performing fresh tier detection")
-                detectAndSaveOptimalTier(context)
-            }
+            val tier = obtainTier(context)
             if (tier == null || tier == LLamaTier.NONE) {
                 Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
                 return@runBlocking null
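
The guard keeps treating null (detection never produced a result) and NONE (detection ran but found no usable optimization) as a single abort case. An equivalent formulation using the stdlib takeUnless, shown only as a sketch inside the same runBlocking block:

    // Illustrative alternative: collapse null and NONE into one early return.
    val usableTier = obtainTier(context)?.takeUnless { it == LLamaTier.NONE }
        ?: return@runBlocking null
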
@@ -83,14 +77,17 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Clear cached detection results (for testing/debugging)
+     * First, attempt to load the detected tier from storage, if available;
+     * otherwise, perform a fresh detection, then save it to storage and the in-memory cache.
      */
-    fun clearCache(context: Context) {
-        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
-        _cachedInstance = null
-        _detectedTier = null
-        Log.i(TAG, "Cleared detection results and cached instance")
-    }
+    private suspend fun obtainTier(context: Context) =
+        loadDetectedTierFromDataStore(context) ?: run {
+            Log.i(TAG, "Performing fresh tier detection")
+            performOptimalTierDetection().also { tier ->
+                tier?.saveToDataStore(context)
+                _detectedTier = tier
+            }
+        }
 
     /**
      * Load cached tier from datastore without performing detection
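
obtainTier is now the single funnel for both callers: read the persisted tier first, and only on a miss run detection, persist the result, and memoize it in _detectedTier. The read side it calls is not shown in this hunk; a sketch of its assumed shape, reusing the llamaTierDataStore extension and DETECTED_TIER key that appear later in this diff:

    import kotlinx.coroutines.flow.first

    // Assumed shape of loadDetectedTierFromDataStore (its body is not in this
    // diff): fetch the stored rawValue once, then map it back onto the enum.
    private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
        val raw = context.llamaTierDataStore.data.first()[DETECTED_TIER] ?: return null
        return LLamaTier.values().firstOrNull { it.rawValue == raw }
    }
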
@@ -112,18 +109,9 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Detect optimal tier and save to cache
+     * Actual implementation of optimal tier detection via native methods
      */
-    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
-        detectOptimalTier().also { tier ->
-            tier?.saveToDataStore(context)
-            _detectedTier = tier
-        }
-
-    /**
-     * Detect optimal tier and save to cache
-     */
-    private fun detectOptimalTier(): LLamaTier? {
+    private fun performOptimalTierDetection(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
@@ -141,9 +129,10 @@ internal object InferenceEngineLoader {
         }
 
         // Ensure we don't exceed maximum supported tier
-        return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
-            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
-            LLamaTier.maxSupportedTier
+        val maxTier = LLamaTier.maxSupportedTier
+        return if (tier.rawValue > maxTier.rawValue) {
+            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
+            maxTier
         } else {
             tier
         }
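
Hoisting maxTier removes the repeated LLamaTier.maxSupportedTier lookups without changing behavior: a detected tier above what this build bundles (say, the native probe reports T2 but only the t0/t1 libraries ship) is clamped down with a warning. For illustration only, the same clamp can be written with the stdlib comparator overload of minOf:

    // Illustrative equivalent: never exceed the bundled maximum tier.
    val effectiveTier = minOf(tier, LLamaTier.maxSupportedTier, compareBy { it.rawValue })
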
@@ -158,6 +147,16 @@ internal object InferenceEngineLoader {
         }
     }
 
+    /**
+     * Clear cached detection results (for testing/debugging)
+     */
+    fun clearCache(context: Context) {
+        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
+        _cachedInstance = null
+        _detectedTier = null
+        Log.i(TAG, "Cleared detection results and cached instance")
+    }
+
     private suspend fun LLamaTier.saveToDataStore(context: Context) {
         context.llamaTierDataStore.edit { prefs ->
             prefs[DETECTED_TIER] = this.rawValue