diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt
index 73e43cf6bb..af62d972ef 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt
@@ -13,7 +13,7 @@ interface TierDetection {
  * Higher tiers provide better performance on supported hardware.
  */
 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
-    NONE(404, "", "No valid Arm optimization available!"),
+    NONE(-404, "", "No valid Arm optimization available!"),
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
     T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
     T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt
index c73ef5b9cb..e01edf2e5f 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt
@@ -42,9 +42,7 @@ internal object InferenceEngineLoader {
      * Get the detected tier, loading from cache if needed
      */
     fun getDetectedTier(context: Context): LLamaTier? =
-        _detectedTier ?: runBlocking {
-            loadDetectedTierFromDataStore(context)
-        }
+        _detectedTier ?: runBlocking { obtainTier(context) }
 
     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
@@ -57,11 +55,7 @@ internal object InferenceEngineLoader {
 
         return runBlocking {
             // Obtain the optimal tier from cache if available
-            val tier = loadDetectedTierFromDataStore(context) ?: run {
-                Log.i(TAG, "Performing fresh tier detection")
-                detectAndSaveOptimalTier(context)
-            }
-
+            val tier = obtainTier(context)
             if (tier == null || tier == LLamaTier.NONE) {
                 Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
                 return@runBlocking null
@@ -83,14 +77,17 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Clear cached detection results (for testing/debugging)
+     * First attempts to load the detected tier from storage, if available;
+     * otherwise performs a fresh detection, then saves it to storage and cache.
      */
-    fun clearCache(context: Context) {
-        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
-        _cachedInstance = null
-        _detectedTier = null
-        Log.i(TAG, "Cleared detection results and cached instance")
-    }
+    private suspend fun obtainTier(context: Context) =
+        loadDetectedTierFromDataStore(context) ?: run {
+            Log.i(TAG, "Performing fresh tier detection")
+            performOptimalTierDetection().also { tier ->
+                tier?.saveToDataStore(context)
+                _detectedTier = tier
+            }
+        }
 
     /**
      * Load cached tier from datastore without performing detection
@@ -112,18 +109,9 @@ internal object InferenceEngineLoader {
     }
 
     /**
-     * Detect optimal tier and save to cache
+     * Actual implementation of optimal tier detection via native methods
      */
-    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
-        detectOptimalTier().also { tier ->
-            tier?.saveToDataStore(context)
-            _detectedTier = tier
-        }
-
-    /**
-     * Detect optimal tier and save to cache
-     */
-    private fun detectOptimalTier(): LLamaTier? {
+    private fun performOptimalTierDetection(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
@@ -141,9 +129,10 @@ internal object InferenceEngineLoader {
         }
 
         // Ensure we don't exceed maximum supported tier
-        return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
-            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
-            LLamaTier.maxSupportedTier
+        val maxTier = LLamaTier.maxSupportedTier
+        return if (tier.rawValue > maxTier.rawValue) {
+            Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
+            maxTier
         } else {
             tier
         }
@@ -158,6 +147,16 @@ internal object InferenceEngineLoader {
         }
     }
 
+    /**
+     * Clear cached detection results (for testing/debugging)
+     */
+    fun clearCache(context: Context) {
+        runBlocking { context.llamaTierDataStore.edit { it.clear() } }
+        _cachedInstance = null
+        _detectedTier = null
+        Log.i(TAG, "Cleared detection results and cached instance")
+    }
+
     private suspend fun LLamaTier.saveToDataStore(context: Context) {
        context.llamaTierDataStore.edit { prefs ->
            prefs[DETECTED_TIER] = this.rawValue
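
For reviewers, a minimal caller-side sketch of the refactored caching flow: a cold call that triggers fresh detection, a warm call served from the cache, and a reset via `clearCache`. Only `getDetectedTier`, `clearCache`, and `LLamaTier` come from the patch above; the `TierDetectionSmokeTest` object, the import paths, and the log messages are illustrative assumptions, and since `InferenceEngineLoader` is `internal`, this would only compile inside the `llama` module (e.g., as an instrumented test).

```kotlin
import android.content.Context
import android.llama.cpp.LLamaTier
import android.llama.cpp.internal.InferenceEngineLoader
import android.util.Log

// Hypothetical smoke test -- not part of this patch.
object TierDetectionSmokeTest {
    private const val TAG = "TierSmokeTest"

    fun run(context: Context) {
        // Cold path: nothing cached yet, so obtainTier() falls through to
        // performOptimalTierDetection() and persists the result.
        val tier = InferenceEngineLoader.getDetectedTier(context)
        if (tier == null || tier == LLamaTier.NONE) {
            // Mirrors the guard used by the factory method in the patch.
            Log.e(TAG, "No usable Arm optimization tier on this device")
        } else {
            Log.i(TAG, "Detected ${tier.name}: ${tier.description} (lib=${tier.libraryName})")
        }

        // Warm path: answered from _detectedTier (or DataStore) without
        // re-running native detection.
        val cached = InferenceEngineLoader.getDetectedTier(context)
        Log.i(TAG, "Cached tier: ${cached?.name}")

        // clearCache() wipes both the DataStore entry and in-memory state,
        // forcing the next getDetectedTier() call to probe the CPU again.
        InferenceEngineLoader.clearCache(context)
    }
}
```

One design note this sketch illustrates: by moving the save-and-cache side effects into `obtainTier`, both public entry points (`getDetectedTier` and the factory method) now share a single detection path, so the warm call above hits the cache regardless of which API populated it.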