From 1f41ae2315248d00bcafdc20a0b15673693a79d9 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Thu, 3 Jul 2025 13:33:05 -0700 Subject: [PATCH] lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier --- .../java/android/llama/cpp/ArmFeatures.kt | 6 +- .../java/android/llama/cpp/KleidiLlama.kt | 2 +- .../java/android/llama/cpp/TierDetection.kt | 1 + .../cpp/internal/InferenceEngineFactory.kt | 2 +- .../llama/cpp/internal/InferenceEngineImpl.kt | 20 +-- .../cpp/internal/InferenceEngineLoader.kt | 119 +++++++++--------- .../llama/cpp/internal/TierDetectionImpl.kt | 2 +- 7 files changed, 78 insertions(+), 74 deletions(-) diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt index 4fdaff5c51..d049e803e4 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt @@ -65,14 +65,14 @@ object ArmFeaturesMapper { * Maps a LLamaTier to its supported ARM features. * Returns a list of booleans where each index corresponds to allFeatures. */ - fun getSupportedFeatures(tier: LLamaTier?): List = + fun getSupportedFeatures(tier: LLamaTier?): List? = when (tier) { + LLamaTier.NONE, null -> null // No tier detected LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE // TODO-han.yin: implement T4 once obtaining an Android device with SME! - null -> listOf(false, false, false, false, false) // No tier detected } /** @@ -83,7 +83,7 @@ object ArmFeaturesMapper { allFeatures.mapIndexed { index, feature -> DisplayItem( feature = feature, - isSupported = flags.getOrElse(index) { false } + isSupported = flags?.getOrElse(index) { false } == true ) } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt index 8a0b264ee4..2adabe7b91 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt @@ -12,7 +12,7 @@ object KleidiLlama { * Create an inference engine instance with automatic tier detection. */ fun createInferenceEngine(context: Context) = - InferenceEngineFactory.createInstance(context) + InferenceEngineFactory.getInstance(context) /** * Get tier detection information for debugging/settings. diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt index 1d5c6566ea..73e43cf6bb 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt @@ -13,6 +13,7 @@ interface TierDetection { * Higher tiers provide better performance on supported hardware. */ enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { + NONE(404, "", "No valid Arm optimization available!"), T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"), T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt index 64d9461a98..66943b99b0 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt @@ -7,7 +7,7 @@ import android.llama.cpp.TierDetection * Internal factory to create [InferenceEngine] and [TierDetection] */ internal object InferenceEngineFactory { - fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context) + fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context) fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt index 38325365c6..8a242f11a3 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt @@ -48,22 +48,22 @@ internal class InferenceEngineImpl private constructor( /** * Create [InferenceEngineImpl] instance with specific tier + * + * @throws IllegalArgumentException if tier's library name is invalid + * @throws UnsatisfiedLinkError if library failed to load */ - internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl? { - if (initialized) { - Log.w(TAG, "LLamaAndroid already initialized") - return null - } + internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl { + assert(!initialized) { "Inference Engine has already been initialized!" } - try { + require(tier.libraryName.isNotBlank()) { "Unexpected library: ${tier.libraryName}" } + + return try { Log.i(TAG, "Instantiating InferenceEngineImpl w/ ${tier.libraryName}") - val instance = InferenceEngineImpl(tier) - initialized = true - return instance + InferenceEngineImpl(tier).also { initialized = true } } catch (e: UnsatisfiedLinkError) { Log.e(TAG, "Failed to load ${tier.libraryName}", e) - return null + throw e } } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt index c23285147f..c73ef5b9cb 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt @@ -37,39 +37,46 @@ internal object InferenceEngineLoader { private var _cachedInstance: InferenceEngineImpl? = null private var _detectedTier: LLamaTier? = null - val detectedTier: LLamaTier? get() = _detectedTier + + /** + * Get the detected tier, loading from cache if needed + */ + fun getDetectedTier(context: Context): LLamaTier? = + _detectedTier ?: runBlocking { + loadDetectedTierFromDataStore(context) + } /** * Factory method to get a configured [InferenceEngineImpl] instance. * Handles tier detection, caching, and library loading automatically. */ @Synchronized - fun createInstance(context: Context): InferenceEngine? { + fun getInstance(context: Context): InferenceEngine? { // Return cached instance if available _cachedInstance?.let { return it } return runBlocking { + // Obtain the optimal tier from cache if available + val tier = loadDetectedTierFromDataStore(context) ?: run { + Log.i(TAG, "Performing fresh tier detection") + detectAndSaveOptimalTier(context) + } + + if (tier == null || tier == LLamaTier.NONE) { + Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier") + return@runBlocking null + } + try { - // Obtain the optimal tier from cache if available - val tier = getOrDetectOptimalTier(context) ?: run { - Log.e(TAG, "Failed to determine optimal tier") - return@runBlocking null - } - _detectedTier = tier - Log.i(TAG, "Using tier: ${tier.name} (${tier.description})") - // Create and cache the inference engine instance - val instance = InferenceEngineImpl.createWithTier(tier) ?: run { - Log.e(TAG, "Failed to instantiate InferenceEngineImpl") - return@runBlocking null + Log.i(TAG, "Using tier: ${tier.name} (${tier.description})") + InferenceEngineImpl.createWithTier(tier).also { + _cachedInstance = it + Log.i(TAG, "Successfully instantiated Inference Engine w/ ${tier.name}") } - _cachedInstance = instance - Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}") - - instance } catch (e: Exception) { - Log.e(TAG, "Error creating InferenceEngineImpl instance", e) + Log.e(TAG, "Error instantiating Inference Engine", e) null } } @@ -86,31 +93,37 @@ internal object InferenceEngineLoader { } /** - * Get optimal tier from cache or detect it fresh + * Load cached tier from datastore without performing detection */ - private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? { + private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? { val preferences = context.llamaTierDataStore.data.first() - - // Check if we have a cached result with the current detection version val cachedVersion = preferences[DETECTION_VERSION] ?: -1 val cachedTierValue = preferences[DETECTED_TIER] ?: -1 - if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { - val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue) - if (cachedTier != null) { - Log.i(TAG, "Using cached tier detection: ${cachedTier.name}") - return cachedTier - } - } - // No valid cache, detect fresh - Log.i(TAG, "Performing fresh tier detection") - return detectAndCacheOptimalTier(context) + return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { + LLamaTier.fromRawValue(cachedTierValue)?.also { + Log.i(TAG, "Loaded cached tier: ${it.name}") + _detectedTier = it + } + } else { + Log.i(TAG, "No valid cached tier found") + null + } } /** * Detect optimal tier and save to cache */ - private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? { + private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? = + detectOptimalTier().also { tier -> + tier?.saveToDataStore(context) + _detectedTier = tier + } + + /** + * Detect optimal tier and save to cache + */ + private fun detectOptimalTier(): LLamaTier? { try { // Load CPU detection library System.loadLibrary("llama_cpu_detector") @@ -122,44 +135,34 @@ internal object InferenceEngineLoader { Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") // Convert to enum and validate - val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run { - Log.w(TAG, "Invalid tier value $tierValue") - return null + val tier = LLamaTier.fromRawValue(tierValue) ?: run { + Log.e(TAG, "Invalid tier value $tierValue") + return LLamaTier.NONE } // Ensure we don't exceed maximum supported tier - val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) { - Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}") - LLamaTier.Companion.maxSupportedTier + return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) { + Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}") + LLamaTier.maxSupportedTier } else { tier } - // Cache the result - context.llamaTierDataStore.edit { - it[DETECTED_TIER] = finalTier.rawValue - it[DETECTION_VERSION] = DATASTORE_VERSION - } - - Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}") - return finalTier - } catch (e: UnsatisfiedLinkError) { Log.e(TAG, "Failed to load CPU detection library", e) - - // Fallback to T0 and cache it - val fallbackTier = LLamaTier.T0 - context.llamaTierDataStore.edit { - it[DETECTED_TIER] = fallbackTier.rawValue - it[DETECTION_VERSION] = DATASTORE_VERSION - } - - Log.i(TAG, "Using fallback tier: ${fallbackTier.name}") - return fallbackTier + return null } catch (e: Exception) { Log.e(TAG, "Unexpected error during tier detection", e) return null } } + + private suspend fun LLamaTier.saveToDataStore(context: Context) { + context.llamaTierDataStore.edit { prefs -> + prefs[DETECTED_TIER] = this.rawValue + prefs[DETECTION_VERSION] = DATASTORE_VERSION + } + Log.i(TAG, "Saved ${this.name} to data store") + } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt index 1d7b731961..c439561aba 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt @@ -9,7 +9,7 @@ import android.llama.cpp.TierDetection */ internal class TierDetectionImpl(private val context: Context) : TierDetection { override val detectedTier: LLamaTier? - get() = InferenceEngineLoader.detectedTier + get() = InferenceEngineLoader.getDetectedTier(context) override fun clearCache() = InferenceEngineLoader.clearCache(context) }