lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier
This commit is contained in:
parent
8c6e449ad2
commit
1f41ae2315
|
|
@ -65,14 +65,14 @@ object ArmFeaturesMapper {
|
|||
* Maps a LLamaTier to its supported ARM features.
|
||||
* Returns a list of booleans where each index corresponds to allFeatures.
|
||||
*/
|
||||
fun getSupportedFeatures(tier: LLamaTier?): List<Boolean> =
|
||||
fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
|
||||
when (tier) {
|
||||
LLamaTier.NONE, null -> null // No tier detected
|
||||
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
null -> listOf(false, false, false, false, false) // No tier detected
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -83,7 +83,7 @@ object ArmFeaturesMapper {
|
|||
allFeatures.mapIndexed { index, feature ->
|
||||
DisplayItem(
|
||||
feature = feature,
|
||||
isSupported = flags.getOrElse(index) { false }
|
||||
isSupported = flags?.getOrElse(index) { false } == true
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ object KleidiLlama {
|
|||
* Create an inference engine instance with automatic tier detection.
|
||||
*/
|
||||
fun createInferenceEngine(context: Context) =
|
||||
InferenceEngineFactory.createInstance(context)
|
||||
InferenceEngineFactory.getInstance(context)
|
||||
|
||||
/**
|
||||
* Get tier detection information for debugging/settings.
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ interface TierDetection {
|
|||
* Higher tiers provide better performance on supported hardware.
|
||||
*/
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
NONE(404, "", "No valid Arm optimization available!"),
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import android.llama.cpp.TierDetection
|
|||
* Internal factory to create [InferenceEngine] and [TierDetection]
|
||||
*/
|
||||
internal object InferenceEngineFactory {
|
||||
fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context)
|
||||
fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)
|
||||
|
||||
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,22 +48,22 @@ internal class InferenceEngineImpl private constructor(
|
|||
|
||||
/**
|
||||
* Create [InferenceEngineImpl] instance with specific tier
|
||||
*
|
||||
* @throws IllegalArgumentException if tier's library name is invalid
|
||||
* @throws UnsatisfiedLinkError if library failed to load
|
||||
*/
|
||||
internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl? {
|
||||
if (initialized) {
|
||||
Log.w(TAG, "LLamaAndroid already initialized")
|
||||
return null
|
||||
}
|
||||
internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl {
|
||||
assert(!initialized) { "Inference Engine has already been initialized!" }
|
||||
|
||||
try {
|
||||
require(tier.libraryName.isNotBlank()) { "Unexpected library: ${tier.libraryName}" }
|
||||
|
||||
return try {
|
||||
Log.i(TAG, "Instantiating InferenceEngineImpl w/ ${tier.libraryName}")
|
||||
val instance = InferenceEngineImpl(tier)
|
||||
initialized = true
|
||||
return instance
|
||||
InferenceEngineImpl(tier).also { initialized = true }
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load ${tier.libraryName}", e)
|
||||
return null
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,39 +37,46 @@ internal object InferenceEngineLoader {
|
|||
|
||||
private var _cachedInstance: InferenceEngineImpl? = null
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
val detectedTier: LLamaTier? get() = _detectedTier
|
||||
|
||||
/**
|
||||
* Get the detected tier, loading from cache if needed
|
||||
*/
|
||||
fun getDetectedTier(context: Context): LLamaTier? =
|
||||
_detectedTier ?: runBlocking {
|
||||
loadDetectedTierFromDataStore(context)
|
||||
}
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
* Handles tier detection, caching, and library loading automatically.
|
||||
*/
|
||||
@Synchronized
|
||||
fun createInstance(context: Context): InferenceEngine? {
|
||||
fun getInstance(context: Context): InferenceEngine? {
|
||||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
return runBlocking {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = loadDetectedTierFromDataStore(context) ?: run {
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
detectAndSaveOptimalTier(context)
|
||||
}
|
||||
|
||||
if (tier == null || tier == LLamaTier.NONE) {
|
||||
Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
|
||||
return@runBlocking null
|
||||
}
|
||||
|
||||
try {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = getOrDetectOptimalTier(context) ?: run {
|
||||
Log.e(TAG, "Failed to determine optimal tier")
|
||||
return@runBlocking null
|
||||
}
|
||||
_detectedTier = tier
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
|
||||
// Create and cache the inference engine instance
|
||||
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
|
||||
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
|
||||
return@runBlocking null
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
InferenceEngineImpl.createWithTier(tier).also {
|
||||
_cachedInstance = it
|
||||
Log.i(TAG, "Successfully instantiated Inference Engine w/ ${tier.name}")
|
||||
}
|
||||
_cachedInstance = instance
|
||||
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
|
||||
|
||||
instance
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
|
||||
Log.e(TAG, "Error instantiating Inference Engine", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
|
@ -86,31 +93,37 @@ internal object InferenceEngineLoader {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get optimal tier from cache or detect it fresh
|
||||
* Load cached tier from datastore without performing detection
|
||||
*/
|
||||
private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
|
||||
private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
|
||||
// Check if we have a cached result with the current detection version
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue)
|
||||
if (cachedTier != null) {
|
||||
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
|
||||
return cachedTier
|
||||
}
|
||||
}
|
||||
|
||||
// No valid cache, detect fresh
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
return detectAndCacheOptimalTier(context)
|
||||
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
LLamaTier.fromRawValue(cachedTierValue)?.also {
|
||||
Log.i(TAG, "Loaded cached tier: ${it.name}")
|
||||
_detectedTier = it
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "No valid cached tier found")
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
|
||||
private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
|
||||
detectOptimalTier().also { tier ->
|
||||
tier?.saveToDataStore(context)
|
||||
_detectedTier = tier
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private fun detectOptimalTier(): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
|
|
@ -122,44 +135,34 @@ internal object InferenceEngineLoader {
|
|||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run {
|
||||
Log.w(TAG, "Invalid tier value $tierValue")
|
||||
return null
|
||||
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
|
||||
Log.e(TAG, "Invalid tier value $tierValue")
|
||||
return LLamaTier.NONE
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}")
|
||||
LLamaTier.Companion.maxSupportedTier
|
||||
return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
|
||||
LLamaTier.maxSupportedTier
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = finalTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
|
||||
return finalTier
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load CPU detection library", e)
|
||||
|
||||
// Fallback to T0 and cache it
|
||||
val fallbackTier = LLamaTier.T0
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = fallbackTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
|
||||
return fallbackTier
|
||||
return null
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Unexpected error during tier detection", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun LLamaTier.saveToDataStore(context: Context) {
|
||||
context.llamaTierDataStore.edit { prefs ->
|
||||
prefs[DETECTED_TIER] = this.rawValue
|
||||
prefs[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
Log.i(TAG, "Saved ${this.name} to data store")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import android.llama.cpp.TierDetection
|
|||
*/
|
||||
internal class TierDetectionImpl(private val context: Context) : TierDetection {
|
||||
override val detectedTier: LLamaTier?
|
||||
get() = InferenceEngineLoader.detectedTier
|
||||
get() = InferenceEngineLoader.getDetectedTier(context)
|
||||
|
||||
override fun clearCache() = InferenceEngineLoader.clearCache(context)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue