lib: optimize engine loader; always perform a fresh detection when cache is null

This commit is contained in:
Han Yin 2025-07-06 16:57:23 -07:00
parent e6413dd05d
commit 4ff924b273
2 changed files with 29 additions and 30 deletions

View File

@@ -13,7 +13,7 @@ interface TierDetection {
* Higher tiers provide better performance on supported hardware. * Higher tiers provide better performance on supported hardware.
*/ */
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
-    NONE(404, "", "No valid Arm optimization available!"),
+    NONE(-404, "", "No valid Arm optimization available!"),
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"), T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),

View File

@@ -42,9 +42,7 @@ internal object InferenceEngineLoader {
* Get the detected tier, loading from cache if needed * Get the detected tier, loading from cache if needed
*/ */
fun getDetectedTier(context: Context): LLamaTier? = fun getDetectedTier(context: Context): LLamaTier? =
_detectedTier ?: runBlocking { _detectedTier ?: runBlocking { obtainTier(context) }
loadDetectedTierFromDataStore(context)
}
/** /**
* Factory method to get a configured [InferenceEngineImpl] instance. * Factory method to get a configured [InferenceEngineImpl] instance.
@@ -57,11 +55,7 @@ internal object InferenceEngineLoader {
return runBlocking { return runBlocking {
// Obtain the optimal tier from cache if available // Obtain the optimal tier from cache if available
val tier = loadDetectedTierFromDataStore(context) ?: run { val tier = obtainTier(context)
Log.i(TAG, "Performing fresh tier detection")
detectAndSaveOptimalTier(context)
}
if (tier == null || tier == LLamaTier.NONE) { if (tier == null || tier == LLamaTier.NONE) {
Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier") Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
return@runBlocking null return@runBlocking null
@@ -83,14 +77,17 @@ internal object InferenceEngineLoader {
} }
/** /**
* Clear cached detection results (for testing/debugging) * First attempt to load detected tier from storage, if available;
* Otherwise, perform a fresh detection, then save to storage and cache.
*/ */
fun clearCache(context: Context) { private suspend fun obtainTier(context: Context) =
runBlocking { context.llamaTierDataStore.edit { it.clear() } } loadDetectedTierFromDataStore(context) ?: run {
_cachedInstance = null Log.i(TAG, "Performing fresh tier detection")
_detectedTier = null performOptimalTierDetection().also { tier ->
Log.i(TAG, "Cleared detection results and cached instance") tier?.saveToDataStore(context)
} _detectedTier = tier
}
}
/** /**
* Load cached tier from datastore without performing detection * Load cached tier from datastore without performing detection
@@ -112,18 +109,9 @@ internal object InferenceEngineLoader {
} }
/** /**
* Detect optimal tier and save to cache * Actual implementation of optimal tier detection via native methods
*/ */
private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? = private fun performOptimalTierDetection(): LLamaTier? {
detectOptimalTier().also { tier ->
tier?.saveToDataStore(context)
_detectedTier = tier
}
/**
* Detect optimal tier and save to cache
*/
private fun detectOptimalTier(): LLamaTier? {
try { try {
// Load CPU detection library // Load CPU detection library
System.loadLibrary("llama_cpu_detector") System.loadLibrary("llama_cpu_detector")
@@ -141,9 +129,10 @@ internal object InferenceEngineLoader {
} }
// Ensure we don't exceed maximum supported tier // Ensure we don't exceed maximum supported tier
return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) { val maxTier = LLamaTier.maxSupportedTier
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}") return if (tier.rawValue > maxTier.rawValue) {
LLamaTier.maxSupportedTier Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
maxTier
} else { } else {
tier tier
} }
@@ -158,6 +147,16 @@ internal object InferenceEngineLoader {
} }
} }
/**
* Clear cached detection results (for testing/debugging)
*/
fun clearCache(context: Context) {
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
_cachedInstance = null
_detectedTier = null
Log.i(TAG, "Cleared detection results and cached instance")
}
private suspend fun LLamaTier.saveToDataStore(context: Context) { private suspend fun LLamaTier.saveToDataStore(context: Context) {
context.llamaTierDataStore.edit { prefs -> context.llamaTierDataStore.edit { prefs ->
prefs[DETECTED_TIER] = this.rawValue prefs[DETECTED_TIER] = this.rawValue