lib: optimize engine loader; always perform a fresh detection when cache is null
commit 4ff924b273
parent e6413dd05d
@@ -13,7 +13,7 @@ interface TierDetection {
  * Higher tiers provide better performance on supported hardware.
  */
 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
-    NONE(404, "", "No valid Arm optimization available!"),
+    NONE(-404, "", "No valid Arm optimization available!"),
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
     T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
     T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
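Moving the `NONE` sentinel from 404 to -404 keeps it outside the non-negative range occupied by the real tiers, so a raw-value lookup can never mistake a persisted tier for the sentinel. A minimal sketch of such a lookup, as a hypothetical helper that is not part of this commit:

```kotlin
// Hypothetical helper; assumes real tiers keep non-negative rawValues.
fun tierFromRawValue(raw: Int): LLamaTier =
    LLamaTier.entries.firstOrNull { it.rawValue == raw } ?: LLamaTier.NONE
```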
@@ -42,9 +42,7 @@ internal object InferenceEngineLoader {
      * Get the detected tier, loading from cache if needed
      */
     fun getDetectedTier(context: Context): LLamaTier? =
-        _detectedTier ?: runBlocking {
-            loadDetectedTierFromDataStore(context)
-        }
+        _detectedTier ?: runBlocking { obtainTier(context) }
 
     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
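`getDetectedTier` still resolves the tier with `runBlocking`, so it blocks the calling thread on a cache miss and is best kept off the main thread. A hypothetical call site (everything outside `InferenceEngineLoader` and `LLamaTier` here is an assumption):

```kotlin
// Hypothetical call site: resolve the tier once during startup, off the main thread.
val tier = InferenceEngineLoader.getDetectedTier(appContext)
if (tier != null && tier != LLamaTier.NONE) {
    Log.i("Startup", "Selected ${tier.libraryName}: ${tier.description}")
}
```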
@@ -57,11 +55,7 @@ internal object InferenceEngineLoader {
 
         return runBlocking {
             // Obtain the optimal tier from cache if available
-            val tier = loadDetectedTierFromDataStore(context) ?: run {
-                Log.i(TAG, "Performing fresh tier detection")
-                detectAndSaveOptimalTier(context)
-            }
+            val tier = obtainTier(context)
 
             if (tier == null || tier == LLamaTier.NONE) {
                 Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
                 return@runBlocking null
@@ -83,14 +77,17 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Clear cached detection results (for testing/debugging)
+     * First attempts to load the detected tier from storage, if available;
+     * otherwise performs a fresh detection, then saves the result to storage and cache.
      */
-    fun clearCache(context: Context) {
-        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
-        _cachedInstance = null
-        _detectedTier = null
-        Log.i(TAG, "Cleared detection results and cached instance")
-    }
+    private suspend fun obtainTier(context: Context) =
+        loadDetectedTierFromDataStore(context) ?: run {
+            Log.i(TAG, "Performing fresh tier detection")
+            performOptimalTierDetection().also { tier ->
+                tier?.saveToDataStore(context)
+                _detectedTier = tier
+            }
+        }
 
     /**
      * Load cached tier from datastore without performing detection
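`obtainTier` delegates its cache read to `loadDetectedTierFromDataStore`, whose body lies outside this diff. A plausible sketch, assuming the `DETECTED_TIER` key shown at the end of the diff and the standard `androidx.datastore` Preferences API:

```kotlin
import kotlinx.coroutines.flow.first

// Sketch only: read the persisted raw value and map it back onto the enum.
private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? =
    context.llamaTierDataStore.data.first()[DETECTED_TIER]
        ?.let { raw -> LLamaTier.entries.firstOrNull { it.rawValue == raw } }
```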
@@ -112,18 +109,9 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Detect optimal tier and save to cache
+     * Actual implementation of optimal tier detection via native methods
      */
-    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
-        detectOptimalTier().also { tier ->
-            tier?.saveToDataStore(context)
-            _detectedTier = tier
-        }
-
-    /**
-     * Detect optimal tier and save to cache
-     */
-    private fun detectOptimalTier(): LLamaTier? {
+    private fun performOptimalTierDetection(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
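The diff shows only the `System.loadLibrary("llama_cpu_detector")` call; the native entry point itself is not visible. Presumably it is a JNI `external` function along these lines, where the function name and return convention are assumptions:

```kotlin
// Hypothetical JNI binding; the real signature lives outside this diff.
private external fun nativeDetectCpuTier(): Int // assumed to return a raw tier value, e.g. 0..2

private fun mapDetectedTier(): LLamaTier? =
    runCatching { nativeDetectCpuTier() }
        .getOrNull()
        ?.let { raw -> LLamaTier.entries.firstOrNull { it.rawValue == raw } }
```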
@@ -141,9 +129,10 @@ internal object InferenceEngineLoader {
         }
 
         // Ensure we don't exceed maximum supported tier
-        return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
-            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
-            LLamaTier.maxSupportedTier
+        val maxTier = LLamaTier.maxSupportedTier
+        return if (tier.rawValue > maxTier.rawValue) {
+            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
+            maxTier
         } else {
             tier
         }
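Because the enum constants are declared in ascending tier order, the clamp could also lean on the enum's natural (ordinal) ordering. A stylistic alternative, not what this commit does, and it drops the warning log:

```kotlin
// Equivalent clamp via Comparable<LLamaTier> (declaration order): minOf keeps the lower tier.
return minOf(tier, LLamaTier.maxSupportedTier)
```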
@@ -158,6 +147,16 @@ internal object InferenceEngineLoader {
         }
     }
 
+    /**
+     * Clear cached detection results (for testing/debugging)
+     */
+    fun clearCache(context: Context) {
+        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
+        _cachedInstance = null
+        _detectedTier = null
+        Log.i(TAG, "Cleared detection results and cached instance")
+    }
+
     private suspend fun LLamaTier.saveToDataStore(context: Context) {
         context.llamaTierDataStore.edit { prefs ->
             prefs[DETECTED_TIER] = this.rawValue
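`llamaTierDataStore` and `DETECTED_TIER` are used throughout but declared elsewhere in the module. The conventional Preferences DataStore plumbing would look roughly like this; the store name and key string below are guesses:

```kotlin
import android.content.Context
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore

// Sketch of the declarations this file relies on; only the identifiers match the diff.
private val Context.llamaTierDataStore by preferencesDataStore(name = "llama_tier_prefs")
private val DETECTED_TIER = intPreferencesKey("detected_tier")
```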