lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier

Author: Han Yin, 2025-07-03 13:33:05 -07:00
parent 8c6e449ad2
commit 1f41ae2315
7 changed files with 78 additions and 74 deletions
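
In short: the loader's public entry point is renamed from createInstance to getInstance, tier detection is split into a cache-read step and a detect-and-save step, and a NONE sentinel tier replaces the old all-false/fallback handling, so callers receive null instead of a misconfigured engine when no valid Arm optimization exists. A minimal sketch of the resulting call path, as hypothetical caller code (initEngine and onUnsupportedDevice are not part of this commit):

    // Hypothetical caller; KleidiLlama.createInferenceEngine() is the library's facade.
    fun initEngine(context: Context) {
        val engine = KleidiLlama.createInferenceEngine(context)
        if (engine == null) {
            // Detection resolved to LLamaTier.NONE or the native library failed to load.
            onUnsupportedDevice()
        }
    }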


@@ -65,14 +65,14 @@ object ArmFeaturesMapper {
      * Maps a LLamaTier to its supported ARM features.
      * Returns a list of booleans where each index corresponds to allFeatures.
      */
-    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean> =
+    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
         when (tier) {
+            LLamaTier.NONE, null -> null // No tier detected
             LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
             LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
             LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
             LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
             // TODO-han.yin: implement T4 once obtaining an Android device with SME!
-            null -> listOf(false, false, false, false, false) // No tier detected
         }

     /**
@@ -83,7 +83,7 @@ object ArmFeaturesMapper {
         allFeatures.mapIndexed { index, feature ->
             DisplayItem(
                 feature = feature,
-                isSupported = flags.getOrElse(index) { false }
+                isSupported = flags?.getOrElse(index) { false } == true
             )
         }
     }
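
getSupportedFeatures is now nullable end to end: NONE and null collapse to a single null branch instead of an all-false list, and the display mapper folds that null back to false per feature via flags?.getOrElse(index) { false } == true. A rough illustration of the contract, with the flag values read off the when-branches above:

    // Sketch only; the exact lists come from the when-expression in this hunk.
    val t1Flags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.T1)
    // -> [true, true, false, false, false] (ASIMD + DOTPROD)
    val noneFlags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.NONE)
    // -> null; every resulting DisplayItem.isSupported becomes false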


@@ -12,7 +12,7 @@ object KleidiLlama {
      * Create an inference engine instance with automatic tier detection.
      */
     fun createInferenceEngine(context: Context) =
-        InferenceEngineFactory.createInstance(context)
+        InferenceEngineFactory.getInstance(context)

     /**
      * Get tier detection information for debugging/settings.
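
Because the loader caches its instance, the facade can be called eagerly. A sketch of app-level wiring (MyApp is hypothetical and not part of this commit):

    import android.app.Application
    import android.util.Log

    class MyApp : Application() {
        override fun onCreate() {
            super.onCreate()
            // First call performs tier detection (or a datastore read) and loads the native library.
            val engine = KleidiLlama.createInferenceEngine(this)
            Log.i("MyApp", if (engine != null) "Inference engine ready" else "Device unsupported")
        }
    }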


@@ -13,6 +13,7 @@ interface TierDetection {
      * Higher tiers provide better performance on supported hardware.
      */
 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
+    NONE(404, "", "No valid Arm optimization available!"),
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
     T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
     T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
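
The sentinel deliberately uses rawValue 404, outside the detector's 0..3 range, and an empty libraryName; both properties are what the downstream guards key off (the require() in InferenceEngineImpl.createWithTier and the NONE check in the loader). Illustrative checks, assuming only what the enum declares:

    val tier = LLamaTier.NONE
    check(tier.rawValue == 404)          // cannot collide with a real detected tier
    check(tier.libraryName.isBlank())    // trips createWithTier's require() if misused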


@@ -7,7 +7,7 @@ import android.llama.cpp.TierDetection
  * Internal factory to create [InferenceEngine] and [TierDetection]
  */
 internal object InferenceEngineFactory {
-    fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context)
+    fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)
     fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
 }


@@ -48,22 +48,22 @@ internal class InferenceEngineImpl private constructor(
         /**
          * Create [InferenceEngineImpl] instance with specific tier
+         *
+         * @throws IllegalArgumentException if tier's library name is invalid
+         * @throws UnsatisfiedLinkError if library failed to load
          */
-        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl? {
-            if (initialized) {
-                Log.w(TAG, "LLamaAndroid already initialized")
-                return null
-            }
-            try {
+        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl {
+            assert(!initialized) { "Inference Engine has already been initialized!" }
+            require(tier.libraryName.isNotBlank()) { "Unexpected library: ${tier.libraryName}" }
+            return try {
                 Log.i(TAG, "Instantiating InferenceEngineImpl w/ ${tier.libraryName}")
-                val instance = InferenceEngineImpl(tier)
-                initialized = true
-                return instance
+                InferenceEngineImpl(tier).also { initialized = true }
             } catch (e: UnsatisfiedLinkError) {
                 Log.e(TAG, "Failed to load ${tier.libraryName}", e)
-                return null
+                throw e
             }
         }
     }
 }
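
createWithTier thus moves from a null-returning contract to a throwing one: IllegalArgumentException for a blank libraryName (i.e. NONE) and a rethrown UnsatisfiedLinkError on load failure. Worth noting that UnsatisfiedLinkError is an Error, not an Exception, so the loader's catch (e: Exception) below does not intercept the rethrow. A sketch of what a direct caller would have to handle (hypothetical; in this commit only InferenceEngineLoader calls this API):

    val engine: InferenceEngineImpl? = try {
        InferenceEngineImpl.createWithTier(tier)
    } catch (e: IllegalArgumentException) {
        null // blank libraryName, e.g. LLamaTier.NONE
    } catch (e: UnsatisfiedLinkError) {
        null // native library missing or built for the wrong ABI
    }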


@@ -37,39 +37,46 @@ internal object InferenceEngineLoader {
     private var _cachedInstance: InferenceEngineImpl? = null
     private var _detectedTier: LLamaTier? = null

-    val detectedTier: LLamaTier? get() = _detectedTier
+    /**
+     * Get the detected tier, loading from cache if needed
+     */
+    fun getDetectedTier(context: Context): LLamaTier? =
+        _detectedTier ?: runBlocking {
+            loadDetectedTierFromDataStore(context)
+        }

     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
      * Handles tier detection, caching, and library loading automatically.
      */
     @Synchronized
-    fun createInstance(context: Context): InferenceEngine? {
+    fun getInstance(context: Context): InferenceEngine? {
         // Return cached instance if available
         _cachedInstance?.let { return it }

         return runBlocking {
-            try {
-                // Obtain the optimal tier from cache if available
-                val tier = getOrDetectOptimalTier(context) ?: run {
-                    Log.e(TAG, "Failed to determine optimal tier")
-                    return@runBlocking null
-                }
-                _detectedTier = tier
-                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
+            // Obtain the optimal tier from cache if available
+            val tier = loadDetectedTierFromDataStore(context) ?: run {
+                Log.i(TAG, "Performing fresh tier detection")
+                detectAndSaveOptimalTier(context)
+            }
+            if (tier == null || tier == LLamaTier.NONE) {
+                Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
+                return@runBlocking null
+            }

+            try {
                 // Create and cache the inference engine instance
-                val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
-                    Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
-                    return@runBlocking null
-                }
-                _cachedInstance = instance
-                Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
-                instance
+                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
+                InferenceEngineImpl.createWithTier(tier).also {
+                    _cachedInstance = it
+                    Log.i(TAG, "Successfully instantiated Inference Engine w/ ${tier.name}")
+                }
             } catch (e: Exception) {
-                Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
+                Log.e(TAG, "Error instantiating Inference Engine", e)
                 null
             }
         }
@@ -86,31 +93,37 @@ internal object InferenceEngineLoader {
     }

     /**
-     * Get optimal tier from cache or detect it fresh
+     * Load cached tier from datastore without performing detection
      */
-    private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
+    private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
         val preferences = context.llamaTierDataStore.data.first()

-        // Check if we have a cached result with the current detection version
         val cachedVersion = preferences[DETECTION_VERSION] ?: -1
         val cachedTierValue = preferences[DETECTED_TIER] ?: -1
-        if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
-            val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue)
-            if (cachedTier != null) {
-                Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
-                return cachedTier
-            }
-        }
-
-        // No valid cache, detect fresh
-        Log.i(TAG, "Performing fresh tier detection")
-        return detectAndCacheOptimalTier(context)
+        return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
+            LLamaTier.fromRawValue(cachedTierValue)?.also {
+                Log.i(TAG, "Loaded cached tier: ${it.name}")
+                _detectedTier = it
+            }
+        } else {
+            Log.i(TAG, "No valid cached tier found")
+            null
+        }
     }

     /**
      * Detect optimal tier and save to cache
      */
-    private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
+    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
+        detectOptimalTier().also { tier ->
+            tier?.saveToDataStore(context)
+            _detectedTier = tier
+        }
+
+    /**
+     * Detect optimal tier and save to cache
+     */
+    private fun detectOptimalTier(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
@@ -122,44 +135,34 @@ internal object InferenceEngineLoader {
             Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")

             // Convert to enum and validate
-            val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run {
-                Log.w(TAG, "Invalid tier value $tierValue")
-                return null
+            val tier = LLamaTier.fromRawValue(tierValue) ?: run {
+                Log.e(TAG, "Invalid tier value $tierValue")
+                return LLamaTier.NONE
             }

             // Ensure we don't exceed maximum supported tier
-            val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) {
-                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}")
-                LLamaTier.Companion.maxSupportedTier
+            return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
+                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
+                LLamaTier.maxSupportedTier
             } else {
                 tier
             }
-
-            // Cache the result
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = finalTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-
-            Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
-            return finalTier
         } catch (e: UnsatisfiedLinkError) {
             Log.e(TAG, "Failed to load CPU detection library", e)
-
-            // Fallback to T0 and cache it
-            val fallbackTier = LLamaTier.T0
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = fallbackTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-            Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
-            return fallbackTier
+            return null
         } catch (e: Exception) {
             Log.e(TAG, "Unexpected error during tier detection", e)
             return null
         }
     }
+
+    private suspend fun LLamaTier.saveToDataStore(context: Context) {
+        context.llamaTierDataStore.edit { prefs ->
+            prefs[DETECTED_TIER] = this.rawValue
+            prefs[DETECTION_VERSION] = DATASTORE_VERSION
+        }
+        Log.i(TAG, "Saved ${this.name} to data store")
+    }
 }
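
End to end, getInstance() now has a warm path (cached instance, then datastore-cached tier) and a cold path (native detection once, then saveToDataStore). The observable behavior, sketched as hypothetical test-style code:

    // Cold start: detection runs once and the tier is persisted.
    val first = InferenceEngineFactory.getInstance(context)
    // Warm start: _cachedInstance short-circuits before any datastore read.
    val second = InferenceEngineFactory.getInstance(context)
    check(first == null || first === second)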


@@ -9,7 +9,7 @@ import android.llama.cpp.TierDetection
  */
 internal class TierDetectionImpl(private val context: Context) : TierDetection {
     override val detectedTier: LLamaTier?
-        get() = InferenceEngineLoader.detectedTier
+        get() = InferenceEngineLoader.getDetectedTier(context)

     override fun clearCache() = InferenceEngineLoader.clearCache(context)
 }
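
Because the getter now delegates to getDetectedTier(context), a settings or debug screen can display the tier even before any engine is created, since the getter falls back to the datastore cache on a fresh process. A sketch (hypothetical settings code, reaching through the internal factory):

    val detection = InferenceEngineFactory.getTierDetection(context)
    val label = detection.detectedTier?.description ?: "Not detected yet"
    // e.g. "ARMv8.2-a with DotProd" on a T1-class device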