lib: change `LlamaTier` to `ArmCpuTier`
This commit is contained in:
parent
3644082a82
commit
266fc314ef
|
|
@ -1,6 +1,6 @@
|
|||
package com.arm.aiplayground.engine
|
||||
|
||||
import com.arm.aichat.LLamaTier
|
||||
import com.arm.aichat.ArmCpuTier
|
||||
import com.arm.aichat.TierDetection
|
||||
import android.util.Log
|
||||
|
||||
|
|
@ -10,7 +10,7 @@ import android.util.Log
|
|||
object StubTierDetection : TierDetection {
|
||||
private val tag = StubTierDetection::class.java.simpleName
|
||||
|
||||
override fun getDetectedTier(): LLamaTier? = LLamaTier.T3
|
||||
override fun getDetectedTier(): ArmCpuTier? = ArmCpuTier.T3
|
||||
|
||||
override fun clearCache() {
|
||||
Log.d(tag, "Cache cleared")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ package com.arm.aiplayground.viewmodel
|
|||
|
||||
import androidx.lifecycle.ViewModel
|
||||
import androidx.lifecycle.viewModelScope
|
||||
import com.arm.aichat.LLamaTier
|
||||
import com.arm.aichat.ArmCpuTier
|
||||
import com.arm.aichat.TierDetection
|
||||
import com.arm.aiplayground.data.repo.ModelRepository
|
||||
import com.arm.aiplayground.data.source.prefs.ColorThemeMode
|
||||
|
|
@ -66,7 +66,7 @@ class SettingsViewModel @Inject constructor(
|
|||
private val _darkThemeMode = MutableStateFlow(DarkThemeMode.AUTO)
|
||||
val darkThemeMode: StateFlow<DarkThemeMode> = _darkThemeMode.asStateFlow()
|
||||
|
||||
val detectedTier: LLamaTier?
|
||||
val detectedTier: ArmCpuTier?
|
||||
get() = tierDetection.getDetectedTier()
|
||||
|
||||
init {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ data class ArmFeature(
|
|||
)
|
||||
|
||||
/**
|
||||
* Helper class to map LLamaTier to supported Arm® features.
|
||||
* Helper class to map [ArmCpuTier] to supported Arm® features.
|
||||
*/
|
||||
object ArmFeaturesMapper {
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ object ArmFeaturesMapper {
|
|||
/**
|
||||
* Gets the feature support data for UI display.
|
||||
*/
|
||||
fun getFeatureDisplayData(tier: LLamaTier?): List<DisplayItem>? =
|
||||
fun getFeatureDisplayData(tier: ArmCpuTier?): List<DisplayItem>? =
|
||||
getSupportedFeatures(tier).let { optFlags ->
|
||||
optFlags?.let { flags ->
|
||||
allFeatures.mapIndexed { index, feature ->
|
||||
|
|
@ -75,16 +75,16 @@ object ArmFeaturesMapper {
|
|||
}
|
||||
|
||||
/**
|
||||
* Maps a LLamaTier to its supported Arm® features.
|
||||
* Maps a [ArmCpuTier] to its supported Arm® features.
|
||||
* Returns a list of booleans where each index corresponds to allFeatures.
|
||||
*/
|
||||
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
|
||||
private fun getSupportedFeatures(tier: ArmCpuTier?): List<Boolean>? =
|
||||
when (tier) {
|
||||
LLamaTier.NONE, null -> null // No tier detected
|
||||
LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||
LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||
ArmCpuTier.NONE, null -> null // No tier detected
|
||||
ArmCpuTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
ArmCpuTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
ArmCpuTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
ArmCpuTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||
ArmCpuTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
package com.arm.aichat
|
||||
|
||||
/**
|
||||
* Public interface for [LLamaTier] detection information.
|
||||
* Public interface for [ArmCpuTier] detection information.
|
||||
*/
|
||||
interface TierDetection {
|
||||
fun getDetectedTier(): LLamaTier?
|
||||
fun getDetectedTier(): ArmCpuTier?
|
||||
fun clearCache()
|
||||
}
|
||||
|
||||
|
|
@ -12,7 +12,7 @@ interface TierDetection {
|
|||
* ARM optimization tiers supported by this library.
|
||||
* Higher tiers provide better performance on supported hardware.
|
||||
*/
|
||||
enum class LLamaTier(val rawValue: Int, val description: String) {
|
||||
enum class ArmCpuTier(val rawValue: Int, val description: String) {
|
||||
NONE(0, "No valid Arm® optimization available!"),
|
||||
T1(1, "ARMv8-a baseline with ASIMD"),
|
||||
T2(2, "ARMv8.2-a with DotProd"),
|
||||
|
|
@ -21,7 +21,7 @@ enum class LLamaTier(val rawValue: Int, val description: String) {
|
|||
T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
||||
|
||||
companion object {
|
||||
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
||||
fun fromRawValue(value: Int): ArmCpuTier? = entries.find { it.rawValue == value }
|
||||
|
||||
val maxSupportedTier = T5
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@ import androidx.datastore.preferences.core.Preferences
|
|||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import com.arm.aichat.LLamaTier
|
||||
import com.arm.aichat.ArmCpuTier
|
||||
import com.arm.aichat.TierDetection
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
/**
|
||||
* Internal [LLamaTier] detection implementation
|
||||
* Internal [ArmCpuTier] detection implementation
|
||||
*/
|
||||
internal class TierDetectionImpl private constructor(
|
||||
private val context: Context
|
||||
|
|
@ -25,7 +25,7 @@ internal class TierDetectionImpl private constructor(
|
|||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "cpu-detection"
|
||||
private const val DATASTORE_VERSION = 1
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
private val Context.armCpuTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
|
|
@ -49,12 +49,12 @@ internal class TierDetectionImpl private constructor(
|
|||
|
||||
private external fun getCpuFeaturesString(): String
|
||||
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
private var _detectedTier: ArmCpuTier? = null
|
||||
|
||||
/**
|
||||
* Get the detected tier, loading from cache if needed
|
||||
*/
|
||||
override fun getDetectedTier(): LLamaTier? =
|
||||
override fun getDetectedTier(): ArmCpuTier? =
|
||||
_detectedTier ?: runBlocking { obtainTier() }
|
||||
|
||||
/**
|
||||
|
|
@ -73,13 +73,13 @@ internal class TierDetectionImpl private constructor(
|
|||
/**
|
||||
* Load cached tier from datastore without performing detection
|
||||
*/
|
||||
private suspend fun loadDetectedTierFromDataStore(): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
private suspend fun loadDetectedTierFromDataStore(): ArmCpuTier? {
|
||||
val preferences = context.armCpuTierDataStore.data.first()
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
|
||||
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
LLamaTier.fromRawValue(cachedTierValue)?.also {
|
||||
ArmCpuTier.fromRawValue(cachedTierValue)?.also {
|
||||
Log.i(TAG, "Loaded cached tier: ${it.name}")
|
||||
_detectedTier = it
|
||||
}
|
||||
|
|
@ -92,7 +92,7 @@ internal class TierDetectionImpl private constructor(
|
|||
/**
|
||||
* Actual implementation of optimal tier detection via native methods
|
||||
*/
|
||||
private fun performOptimalTierDetection(): LLamaTier? {
|
||||
private fun performOptimalTierDetection(): ArmCpuTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("cpu-detector")
|
||||
|
|
@ -104,13 +104,13 @@ internal class TierDetectionImpl private constructor(
|
|||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
|
||||
val tier = ArmCpuTier.fromRawValue(tierValue) ?: run {
|
||||
Log.e(TAG, "Invalid tier value $tierValue")
|
||||
return LLamaTier.NONE
|
||||
return ArmCpuTier.NONE
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val maxTier = LLamaTier.maxSupportedTier
|
||||
val maxTier = ArmCpuTier.maxSupportedTier
|
||||
return if (tier.rawValue > maxTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
|
||||
maxTier
|
||||
|
|
@ -132,13 +132,13 @@ internal class TierDetectionImpl private constructor(
|
|||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
override fun clearCache() {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
runBlocking { context.armCpuTierDataStore.edit { it.clear() } }
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared CPU detection results")
|
||||
}
|
||||
|
||||
private suspend fun LLamaTier.saveToDataStore() {
|
||||
context.llamaTierDataStore.edit { prefs ->
|
||||
private suspend fun ArmCpuTier.saveToDataStore() {
|
||||
context.armCpuTierDataStore.edit { prefs ->
|
||||
prefs[DETECTED_TIER] = this.rawValue
|
||||
prefs[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue