lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier
parent 8c6e449ad2
commit 1f41ae2315
@@ -65,14 +65,14 @@ object ArmFeaturesMapper {
      * Maps a LLamaTier to its supported ARM features.
      * Returns a list of booleans where each index corresponds to allFeatures.
      */
-    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean> =
+    fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
         when (tier) {
+            LLamaTier.NONE, null -> null // No tier detected
             LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
             LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
             LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
             LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
             // TODO-han.yin: implement T4 once obtaining an Android device with SME!
-            null -> listOf(false, false, false, false, false) // No tier detected
         }

     /**
@@ -83,7 +83,7 @@ object ArmFeaturesMapper {
         allFeatures.mapIndexed { index, feature ->
             DisplayItem(
                 feature = feature,
-                isSupported = flags.getOrElse(index) { false }
+                isSupported = flags?.getOrElse(index) { false } == true
             )
         }
     }
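Note: `getSupportedFeatures` now distinguishes "no tier detected" (null) from "tier with no features" (an all-false list), and the `== true` comparison lets the null flow through the display mapping safely. A minimal sketch of the resulting behavior (illustrative, not part of this commit; it relies only on the five-flag lists shown above):

    // NONE and null now yield no flags at all, instead of an all-false list
    val noneFlags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.NONE) // null
    val t1Flags = ArmFeaturesMapper.getSupportedFeatures(LLamaTier.T1)     // [true, true, false, false, false]

    // At the display layer, `flags?.getOrElse(index) { false } == true`
    // collapses both a null list and an out-of-range index to "unsupported"
    val asimdShown = noneFlags?.getOrElse(0) { false } == true             // false
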
@@ -12,7 +12,7 @@ object KleidiLlama {
      * Create an inference engine instance with automatic tier detection.
      */
     fun createInferenceEngine(context: Context) =
-        InferenceEngineFactory.createInstance(context)
+        InferenceEngineFactory.getInstance(context)

     /**
      * Get tier detection information for debugging/settings.
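Since `createInferenceEngine` now delegates to the cached `getInstance` path, a caller still receives null when no usable tier exists. A usage sketch (the call site and error handling are illustrative, not from the commit):

    // Repeat calls return the cached engine; null means tier detection
    // resolved to NONE or instantiation failed.
    val engine = KleidiLlama.createInferenceEngine(applicationContext)
        ?: error("No compatible Arm-optimized llama library on this device")
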
@@ -13,6 +13,7 @@ interface TierDetection {
      * Higher tiers provide better performance on supported hardware.
      */
     enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
+        NONE(404, "", "No valid Arm optimization available!"),
         T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
         T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
         T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
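The new `NONE` entry is a sentinel: rawValue 404 keeps it clear of real tier values, and its blank `libraryName` must never reach `System.loadLibrary`. A sketch of a guard at a load site, assuming only the enum as declared above (the helper is hypothetical):

    // NONE (blank libraryName) maps to null and is never loaded
    fun libraryFor(tier: LLamaTier): String? =
        tier.libraryName.takeIf { it.isNotBlank() }
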
@@ -7,7 +7,7 @@ import android.llama.cpp.TierDetection
  * Internal factory to create [InferenceEngine] and [TierDetection]
  */
 internal object InferenceEngineFactory {
-    fun createInstance(context: Context) = InferenceEngineLoader.createInstance(context)
+    fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)

     fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
 }
@@ -48,22 +48,22 @@ internal class InferenceEngineImpl private constructor(

         /**
          * Create [InferenceEngineImpl] instance with specific tier
+         *
+         * @throws IllegalArgumentException if tier's library name is invalid
+         * @throws UnsatisfiedLinkError if library failed to load
          */
-        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl? {
-            if (initialized) {
-                Log.w(TAG, "LLamaAndroid already initialized")
-                return null
-            }
+        internal fun createWithTier(tier: LLamaTier): InferenceEngineImpl {
+            assert(!initialized) { "Inference Engine has already been initialized!" }

-            try {
+            require(tier.libraryName.isNotBlank()) { "Unexpected library: ${tier.libraryName}" }
+
+            return try {
                 Log.i(TAG, "Instantiating InferenceEngineImpl w/ ${tier.libraryName}")
-                val instance = InferenceEngineImpl(tier)
-                initialized = true
-                return instance
+                InferenceEngineImpl(tier).also { initialized = true }
             } catch (e: UnsatisfiedLinkError) {
                 Log.e(TAG, "Failed to load ${tier.libraryName}", e)
-                return null
+                throw e
             }
         }
     }
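`createWithTier` no longer signals failure by returning null; it throws, and the `require` check rejects `NONE` up front via its blank library name. A caller that still wants a nullable result can translate the exceptions back, as in this sketch (the wrapper is illustrative, not part of the commit):

    fun tryCreateWithTier(tier: LLamaTier): InferenceEngineImpl? =
        try {
            InferenceEngineImpl.createWithTier(tier)
        } catch (e: IllegalArgumentException) {
            null // blank libraryName, e.g. LLamaTier.NONE
        } catch (e: UnsatisfiedLinkError) {
            null // native library missing for this ABI
        }
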
@@ -37,39 +37,46 @@ internal object InferenceEngineLoader {

     private var _cachedInstance: InferenceEngineImpl? = null
     private var _detectedTier: LLamaTier? = null
-    val detectedTier: LLamaTier? get() = _detectedTier
+
+    /**
+     * Get the detected tier, loading from cache if needed
+     */
+    fun getDetectedTier(context: Context): LLamaTier? =
+        _detectedTier ?: runBlocking {
+            loadDetectedTierFromDataStore(context)
+        }

     /**
      * Factory method to get a configured [InferenceEngineImpl] instance.
      * Handles tier detection, caching, and library loading automatically.
      */
     @Synchronized
-    fun createInstance(context: Context): InferenceEngine? {
+    fun getInstance(context: Context): InferenceEngine? {
         // Return cached instance if available
         _cachedInstance?.let { return it }

         return runBlocking {
+            // Obtain the optimal tier from cache if available
+            val tier = loadDetectedTierFromDataStore(context) ?: run {
+                Log.i(TAG, "Performing fresh tier detection")
+                detectAndSaveOptimalTier(context)
+            }
+
+            if (tier == null || tier == LLamaTier.NONE) {
+                Log.e(TAG, "Aborted instantiating Inference Engine due to invalid tier")
+                return@runBlocking null
+            }
+
             try {
-                // Obtain the optimal tier from cache if available
-                val tier = getOrDetectOptimalTier(context) ?: run {
-                    Log.e(TAG, "Failed to determine optimal tier")
-                    return@runBlocking null
-                }
-                _detectedTier = tier
-                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
-
                 // Create and cache the inference engine instance
-                val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
-                    Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
-                    return@runBlocking null
-                }
-                _cachedInstance = instance
-                Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
-
-                instance
+                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
+                InferenceEngineImpl.createWithTier(tier).also {
+                    _cachedInstance = it
+                    Log.i(TAG, "Successfully instantiated Inference Engine w/ ${tier.name}")
+                }
             } catch (e: Exception) {
-                Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
+                Log.e(TAG, "Error instantiating Inference Engine", e)
                 null
             }
         }
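The tier check now happens before the try block, so a `NONE` result short-circuits without touching the native loader, and the cache makes repeat calls cheap. A usage sketch (the `appContext` variable is assumed):

    // Only the first call performs tier detection and library loading;
    // later calls return the cached instance (or null again on NONE).
    val first = InferenceEngineLoader.getInstance(appContext)
    val second = InferenceEngineLoader.getInstance(appContext)
    check(first == null || first === second)
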
@@ -86,31 +93,37 @@ internal object InferenceEngineLoader {
     }

     /**
-     * Get optimal tier from cache or detect it fresh
+     * Load cached tier from datastore without performing detection
      */
-    private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
+    private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
         val preferences = context.llamaTierDataStore.data.first()

-        // Check if we have a cached result with the current detection version
         val cachedVersion = preferences[DETECTION_VERSION] ?: -1
         val cachedTierValue = preferences[DETECTED_TIER] ?: -1
-        if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
-            val cachedTier = LLamaTier.Companion.fromRawValue(cachedTierValue)
-            if (cachedTier != null) {
-                Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
-                return cachedTier
-            }
-        }
-
-        // No valid cache, detect fresh
-        Log.i(TAG, "Performing fresh tier detection")
-        return detectAndCacheOptimalTier(context)
+        return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
+            LLamaTier.fromRawValue(cachedTierValue)?.also {
+                Log.i(TAG, "Loaded cached tier: ${it.name}")
+                _detectedTier = it
+            }
+        } else {
+            Log.i(TAG, "No valid cached tier found")
+            null
+        }
     }

     /**
      * Detect optimal tier and save to cache
      */
-    private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
+    private suspend fun detectAndSaveOptimalTier(context: Context): LLamaTier? =
+        detectOptimalTier().also { tier ->
+            tier?.saveToDataStore(context)
+            _detectedTier = tier
+        }
+
+    /**
+     * Detect the optimal tier via the CPU detection library
+     */
+    private fun detectOptimalTier(): LLamaTier? {
         try {
             // Load CPU detection library
             System.loadLibrary("llama_cpu_detector")
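The loader reads `DETECTED_TIER` and `DETECTION_VERSION` from a Preferences DataStore whose declarations sit outside this diff. A sketch of the assumed plumbing (names match the diff; the store name and `DATASTORE_VERSION` value are guesses):

    import androidx.datastore.preferences.core.intPreferencesKey
    import androidx.datastore.preferences.preferencesDataStore

    // Assumed declarations, for illustration only
    private val Context.llamaTierDataStore by preferencesDataStore(name = "llama_tier")
    private val DETECTED_TIER = intPreferencesKey("detected_tier")
    private val DETECTION_VERSION = intPreferencesKey("detection_version")
    private const val DATASTORE_VERSION = 1 // bumping invalidates cached detection
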
@@ -122,44 +135,34 @@ internal object InferenceEngineLoader {
             Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")

             // Convert to enum and validate
-            val tier = LLamaTier.Companion.fromRawValue(tierValue) ?: run {
-                Log.w(TAG, "Invalid tier value $tierValue")
-                return null
+            val tier = LLamaTier.fromRawValue(tierValue) ?: run {
+                Log.e(TAG, "Invalid tier value $tierValue")
+                return LLamaTier.NONE
             }

             // Ensure we don't exceed maximum supported tier
-            val finalTier = if (tier.rawValue > LLamaTier.Companion.maxSupportedTier.rawValue) {
-                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.Companion.maxSupportedTier.name}")
-                LLamaTier.Companion.maxSupportedTier
+            return if (tier.rawValue > LLamaTier.maxSupportedTier.rawValue) {
+                Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${LLamaTier.maxSupportedTier.name}")
+                LLamaTier.maxSupportedTier
             } else {
                 tier
             }
-
-            // Cache the result
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = finalTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-
-            Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
-            return finalTier
         } catch (e: UnsatisfiedLinkError) {
             Log.e(TAG, "Failed to load CPU detection library", e)
-
-            // Fallback to T0 and cache it
-            val fallbackTier = LLamaTier.T0
-            context.llamaTierDataStore.edit {
-                it[DETECTED_TIER] = fallbackTier.rawValue
-                it[DETECTION_VERSION] = DATASTORE_VERSION
-            }
-
-            Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
-            return fallbackTier
+            return null
         } catch (e: Exception) {
             Log.e(TAG, "Unexpected error during tier detection", e)
             return null
         }
     }
+
+    private suspend fun LLamaTier.saveToDataStore(context: Context) {
+        context.llamaTierDataStore.edit { prefs ->
+            prefs[DETECTED_TIER] = this.rawValue
+            prefs[DETECTION_VERSION] = DATASTORE_VERSION
+        }
+        Log.i(TAG, "Saved ${this.name} to data store")
+    }
 }
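Detection now persists only through `saveToDataStore`, and an unrecognized raw value maps to `NONE` instead of silently falling back to T0. The companion helpers the code relies on are not shown in this diff; an assumed sketch of their shape (illustrative only, requires Kotlin 1.9+ for `entries`):

    // Assumed shape of the companion helpers used above
    companion object {
        val maxSupportedTier = T3 // TODO T4 pending an SME-capable device
        fun fromRawValue(raw: Int): LLamaTier? =
            entries.firstOrNull { it.rawValue == raw }
    }
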
@@ -9,7 +9,7 @@ import android.llama.cpp.TierDetection
  */
 internal class TierDetectionImpl(private val context: Context) : TierDetection {
     override val detectedTier: LLamaTier?
-        get() = InferenceEngineLoader.detectedTier
+        get() = InferenceEngineLoader.getDetectedTier(context)

     override fun clearCache() = InferenceEngineLoader.clearCache(context)
 }
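`detectedTier` is now context-aware, so it can surface a tier cached in DataStore even before any engine is created. A debugging sketch (the call site is illustrative):

    // e.g. in a settings/debug screen
    val detection = InferenceEngineFactory.getTierDetection(context)
    val label = detection.detectedTier?.let { "${it.name}: ${it.description}" }
        ?: "Tier not detected yet"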