lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier

Han Yin · 2025-07-03 13:33:05 -07:00
parent 8c6e449ad2 · commit 1f41ae2315
7 changed files with 78 additions and 74 deletions


@@ -65,14 +65,14 @@ object ArmFeaturesMapper {
      * Maps a LLamaTier to its supported ARM features.
      * Returns a list of booleans where each index corresponds to allFeatures.
      */
-    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean> =
+    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
         when (tier) {
+            LLamaTier.NONE, null -> null // No tier detected
            LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
            LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
            LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
            LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
            // TODO-han.yin: implement T4 once obtaining an Android device with SME!
-            null -> listOf(false, false, false, false, false) // No tier detected
        }

@@ -83,7 +83,7 @@ object ArmFeaturesMapper {
         allFeatures.mapIndexed { index, feature ->
             DisplayItem(
                 feature = feature,
-                isSupported = flags.getOrElse(index) { false }
+                isSupported = flags?.getOrElse(index) { false } == true
             )
         }
     }
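Reading the two hunks together: the mapper now returns null instead of an all-false list when no tier is detected, and the display layer folds that null back into "unsupported" flags. A minimal usage sketch (the allFeatures contents are inferred from the tier comments above, not shown in this diff):

// Hypothetical call site; feature order assumed from the comments above.
val allFeatures = listOf("ASIMD", "DOTPROD", "I8MM", "SVE", "SME")

val t2Flags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.T2)     // [true, true, true, false, false]
val noneFlags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.NONE) // null

// `flags?.getOrElse(index) { false } == true` maps a null list and any
// out-of-range index to false, so NONE renders as "nothing supported".
allFeatures.forEachIndexed { index, feature ->
    println("$feature supported: ${noneFlags?.getOrElse(index) { false } == true}")
}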


@@ -12,7 +12,7 @@ object KleidiLlama {
      * Create an inference engine instance with automatic tier detection.
      */
     fun createInferenceEngine(context: Context) =
-        InferenceEngineFactory.createInstance(context)
+        InferenceEngineFactory.getInstance(context)

     /**
      * Get tier detection information for debugging/settings.
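For context, a caller sketch of the public entry point (assumes an Android Context in scope; the null result is the new behavior wired up in the loader below):

// Hypothetical consumer; createInferenceEngine() now returns null when no
// valid tier exists (LLamaTier.NONE) or the native library cannot be loaded.
val engine = KleidiLlama.createInferenceEngine(context)
if (engine == null) {
    Log.e("App", "No usable inference engine on this device")
}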


@@ -13,6 +13,7 @@ interface TierDetection {
  * Higher tiers provide better performance on supported hardware.
  */
 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
+    NONE(404, "", "No valid Arm optimization available!"),
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
     T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
     T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
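The loader below relies on two companion helpers of LLamaTier that this diff does not show; a plausible sketch, an assumption rather than the committed code:

// Assumed shape of the companion helpers used by InferenceEngineLoader.
companion object {
    // NONE's rawValue (404) sits outside the detector's 0-based range; raw
    // detector output maps to T0..T3, while a cached 404 round-trips to NONE.
    fun fromRawValue(value: Int): LLamaTier? =
        entries.firstOrNull { it.rawValue == value }

    // Highest tier with a shipped library; T4 (SME) is still a TODO.
    val maxSupportedTier: LLamaTier get() = T3
}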


@@ -7,7 +7,7 @@ import android.llama.cpp.TierDetection
  * Internal factory to create [InferenceEngine] and [TierDetection]
  */
 internal object InferenceEngineFactory {
-    fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context)
+    fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)
     fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
 }
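The rename from createInstance to getInstance signals memoization rather than fresh construction; a hypothetical check of that contract:

// Hypothetical test sketch: repeated calls should return the cached engine.
val first = InferenceEngineFactory.getInstance(context)
val second = InferenceEngineFactory.getInstance(context)
check(first == null || first === second) { "getInstance should be idempotent" }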


@@ -48,22 +48,22 @@ internal class InferenceEngineImpl private constructor(
         /**
          * Create [InferenceEngineImpl] instance with specific tier
          *
+         * @throws IllegalArgumentException if tier's library name is invalid
+         * @throws UnsatisfiedLinkError if library failed to load
          */
-        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl? {
-            if (initialized) {
-                Log.w(TAG, "LLamaAndroid already initialized")
-                return null
-            }
+        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl {
+            assert(!initialized) { "Inference Engine has already been initialized!" }
-            try {
+            require(tier.libraryName.isNotBlank()) { "Unexpected library: ${tier.libraryName}" }
+            return try {
                 Log.i(TAG, "Instantiating InferenceEngineImpl w/ ${tier.libraryName}")
-                val instance = InferenceEngineImpl(tier)
-                initialized = true
-                return instance
+                InferenceEngineImpl(tier).also { initialized = true }
             } catch (e: UnsatisfiedLinkError) {
                 Log.e(TAG, "Failed to load ${tier.libraryName}", e)
-                return null
+                throw e
             }
         }
     }
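createWithTier now fails fast (non-null return or an exception) instead of signalling failure with null. A caller-side sketch, assuming a tier in scope; note that Kotlin's assert follows JVM assertion status and is typically a no-op on Android, so the double-init guard is advisory rather than enforced:

// Sketch of consuming the new fail-fast contract.
val engine: InferenceEngineImpl? = try {
    InferenceEngineImpl.createWithTier(tier)
} catch (e: IllegalArgumentException) {
    null // tier.libraryName was blank, e.g. LLamaTier.NONE
} catch (e: UnsatisfiedLinkError) {
    null // the tier's native library could not be loaded
}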


@@ -37,39 +37,46 @@ internal object InferenceEngineLoader {
     private var _cachedInstance: InferenceEngineImpl? = null
     private var _detectedTier: LLamaTier? = null
-    val detectedTier: LLamaTier? get() = _detectedTier
+
+    /**
+     * Get the detected tier, loading from cache if needed
+     */
+    fun getDetectedTier(context: Context): LLamaTier? =
+        _detectedTier ?: runBlocking {
+            loadDetectedTierFromDataStore(context)
+        }

     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
      * Handles tier detection, caching, and library loading automatically.
      */
     @Synchronized
-    fun createInstance(context: Context): InferenceEngine? {
+    fun getInstance(context: Context): InferenceEngine? {
         // Return cached instance if available
         _cachedInstance?.let { return it }

         return runBlocking {
+            // Obtain the optimal tier from cache if available
+            val tier = loadDetectedTierFromDataStore(context) ?: run {
+                Log.i(TAG, "Performing fresh tier detection")
+                detectAndSaveOptimalTier(context)
+            }
+            if (tier == null || tier == LLamaTier.NONE) {
+                Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
+                return@runBlocking null
+            }
             try {
-                // Obtain the optimal tier from cache if available
-                val tier = getOrDetectOptimalTier(context) ?: run {
-                    Log.e(TAG, "Failed to determine optimal tier")
-                    return@runBlocking null
-                }
-                _detectedTier = tier
-                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
-                // Create and cache the inference engine instance
-                val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
-                    Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
-                    return@runBlocking null
+                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
+                InferenceEngineImpl.createWithTier(tier).also {
+                    _cachedInstance = it
+                    Log.i(TAG, "Successfully instantiated Inference Engine w/ ${tier.name}")
                 }
-                _cachedInstance = instance
-                Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
-                instance
             } catch (e: Exception) {
-                Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
+                Log.e(TAG, "Error instantiating Inference Engine", e)
                 null
             }
         }
@@ -86,31 +93,37 @@ internal object InferenceEngineLoader {
     }

     /**
-     * Get optimal tier from cache or detect it fresh
+     * Load cached tier from datastore without performing detection
      */
-    private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
+    private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
         val preferences = context.llamaTierDataStore.data.first()
         // Check if we have a cached result with the current detection version
         val cachedVersion = preferences[DETECTION_VERSION] ?: -1
         val cachedTierValue = preferences[DETECTED_TIER] ?: -1
-        if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
-            val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue)
-            if (cachedTier != null) {
-                Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
-                return cachedTier
-            }
-        }
-        // No valid cache, detect fresh
-        Log.i(TAG, "Performing fresh tier detection")
-        return detectAndCacheOptimalTier(context)
+        return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
+            LLamaTier.fromRawValue(cachedTierValue)?.also {
+                Log.i(TAG, "Loaded cached tier: ${it.name}")
+                _detectedTier = it
+            }
+        } else {
+            Log.i(TAG, "No valid cached tier found")
+            null
+        }
     }

     /**
      * Detect optimal tier and save to cache
      */
-    private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
+    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
+        detectOptimalTier().also { tier ->
+            tier?.saveToDataStore(context)
+            _detectedTier = tier
+        }
+
+    /**
+     * Detect the optimal tier via the native CPU detector (detection only; the caller persists it)
+     */
+    private fun detectOptimalTier(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
@@ -122,44 +135,34 @@ internal object InferenceEngineLoader {
             Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
             // Convert to enum and validate
-            val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run {
-                Log.w(TAG, "Invalid tier value $tierValue")
-                return null
+            val tier = LLamaTier.fromRawValue(tierValue) ?: run {
+                Log.e(TAG, "Invalid tier value $tierValue")
+                return LLamaTier.NONE
             }
             // Ensure we don't exceed maximum supported tier
-            val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) {
-                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}")
-                LLamaTier.Companion.maxSupportedTier
+            return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
+                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
+                LLamaTier.maxSupportedTier
             } else {
                 tier
             }
-            // Cache the result
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = finalTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-            Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
-            return finalTier
         } catch (e: UnsatisfiedLinkError) {
             Log.e(TAG, "Failed to load CPU detection library", e)
-            // Fallback to T0 and cache it
-            val fallbackTier = LLamaTier.T0
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = fallbackTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-            Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
-            return fallbackTier
+            return null
         } catch (e: Exception) {
             Log.e(TAG, "Unexpected error during tier detection", e)
             return null
         }
     }

+    private suspend fun LLamaTier.saveToDataStore(context: Context) {
+        context.llamaTierDataStore.edit { prefs ->
+            prefs[DETECTED_TIER] = this.rawValue
+            prefs[DETECTION_VERSION] = DATASTORE_VERSION
+        }
+        Log.i(TAG, "Saved ${this.name} to data store")
+    }
 }
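The loader assumes Preferences DataStore plumbing declared elsewhere in the file; a sketch of what it presumably looks like (property and key identifiers match the calls above, while the string literals and version number are assumptions):

import android.content.Context
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore

private val Context.llamaTierDataStore: DataStore<Preferences> by
    preferencesDataStore(name = "llama_tier_cache") // name assumed

private val DETECTED_TIER = intPreferencesKey("detected_tier")          // key name assumed
private val DETECTION_VERSION = intPreferencesKey("detection_version")  // key name assumed
private const val DATASTORE_VERSION = 1 // actual value not shown in this diff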


@@ -9,7 +9,7 @@ import android.llama.cpp.TierDetection
  */
 internal class TierDetectionImpl(private val context: Context) : TierDetection {
     override val detectedTier: LLamaTier?
-        get() = InferenceEngineLoader.detectedTier
+        get() = InferenceEngineLoader.getDetectedTier(context)

     override fun clearCache() = InferenceEngineLoader.clearCache(context)
 }