lib: optimize engine loader; always perform a fresh detection when cache is null
This commit is contained in:
parent
e6413dd05d
commit
4ff924b273
|
|
@ -13,7 +13,7 @@ interface TierDetection {
|
|||
* Higher tiers provide better performance on supported hardware.
|
||||
*/
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
NONE(404, "", "No valid Arm optimization available!"),
|
||||
NONE(-404, "", "No valid Arm optimization available!"),
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
|
|
|
|||
|
|
@ -42,9 +42,7 @@ internal object InferenceEngineLoader {
|
|||
* Get the detected tier, loading from cache if needed
|
||||
*/
|
||||
fun getDetectedTier(context: Context): LLamaTier? =
|
||||
_detectedTier ?: runBlocking {
|
||||
loadDetectedTierFromDataStore(context)
|
||||
}
|
||||
_detectedTier ?: runBlocking { obtainTier(context) }
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
|
|
@ -57,11 +55,7 @@ internal object InferenceEngineLoader {
|
|||
|
||||
return runBlocking {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = loadDetectedTierFromDataStore(context) ?: run {
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
detectAndSaveOptimalTier(context)
|
||||
}
|
||||
|
||||
val tier = obtainTier(context)
|
||||
if (tier == null || tier == LLamaTier.NONE) {
|
||||
Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
|
||||
return@runBlocking null
|
||||
|
|
@ -83,14 +77,17 @@ internal object InferenceEngineLoader {
|
|||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
* First attempt to load detected tier from storage, if available;
|
||||
* Otherwise, perform a fresh detection, then save to storage and cache.
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
}
|
||||
private suspend fun obtainTier(context: Context) =
|
||||
loadDetectedTierFromDataStore(context) ?: run {
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
performOptimalTierDetection().also { tier ->
|
||||
tier?.saveToDataStore(context)
|
||||
_detectedTier = tier
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load cached tier from datastore without performing detection
|
||||
|
|
@ -112,18 +109,9 @@ internal object InferenceEngineLoader {
|
|||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
* Actual implementation of optimal tier detection via native methods
|
||||
*/
|
||||
private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
|
||||
detectOptimalTier().also { tier ->
|
||||
tier?.saveToDataStore(context)
|
||||
_detectedTier = tier
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private fun detectOptimalTier(): LLamaTier? {
|
||||
private fun performOptimalTierDetection(): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
|
|
@ -141,9 +129,10 @@ internal object InferenceEngineLoader {
|
|||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
|
||||
LLamaTier.maxSupportedTier
|
||||
val maxTier = LLamaTier.maxSupportedTier
|
||||
return if (tier.rawValue > maxTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
|
||||
maxTier
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
|
@ -158,6 +147,16 @@ internal object InferenceEngineLoader {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
}
|
||||
|
||||
private suspend fun LLamaTier.saveToDataStore(context: Context) {
|
||||
context.llamaTierDataStore.edit { prefs ->
|
||||
prefs[DETECTED_TIER] = this.rawValue
|
||||
|
|
|
|||
Loading…
Reference in New Issue