lib: change `LlamaTier` to `ArmCpuTier`

This commit is contained in:
Han Yin 2025-10-12 15:59:51 -07:00
parent 3644082a82
commit 266fc314ef
5 changed files with 33 additions and 33 deletions

View File

@ -1,6 +1,6 @@
package com.arm.aiplayground.engine
import com.arm.aichat.LLamaTier
import com.arm.aichat.ArmCpuTier
import com.arm.aichat.TierDetection
import android.util.Log
@ -10,7 +10,7 @@ import android.util.Log
object StubTierDetection : TierDetection {
private val tag = StubTierDetection::class.java.simpleName
override fun getDetectedTier(): LLamaTier? = LLamaTier.T3
override fun getDetectedTier(): ArmCpuTier? = ArmCpuTier.T3
override fun clearCache() {
Log.d(tag, "Cache cleared")

View File

@ -2,7 +2,7 @@ package com.arm.aiplayground.viewmodel
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.arm.aichat.LLamaTier
import com.arm.aichat.ArmCpuTier
import com.arm.aichat.TierDetection
import com.arm.aiplayground.data.repo.ModelRepository
import com.arm.aiplayground.data.source.prefs.ColorThemeMode
@ -66,7 +66,7 @@ class SettingsViewModel @Inject constructor(
private val _darkThemeMode = MutableStateFlow(DarkThemeMode.AUTO)
val darkThemeMode: StateFlow<DarkThemeMode> = _darkThemeMode.asStateFlow()
val detectedTier: LLamaTier?
val detectedTier: ArmCpuTier?
get() = tierDetection.getDetectedTier()
init {

View File

@ -11,7 +11,7 @@ data class ArmFeature(
)
/**
* Helper class to map LLamaTier to supported Arm® features.
* Helper class to map [ArmCpuTier] to supported Arm® features.
*/
object ArmFeaturesMapper {
@ -62,7 +62,7 @@ object ArmFeaturesMapper {
/**
* Gets the feature support data for UI display.
*/
fun getFeatureDisplayData(tier: LLamaTier?): List<DisplayItem>? =
fun getFeatureDisplayData(tier: ArmCpuTier?): List<DisplayItem>? =
getSupportedFeatures(tier).let { optFlags ->
optFlags?.let { flags ->
allFeatures.mapIndexed { index, feature ->
@ -75,16 +75,16 @@ object ArmFeaturesMapper {
}
/**
* Maps a LLamaTier to its supported Arm® features.
* Maps a [ArmCpuTier] to its supported Arm® features.
* Returns a list of booleans where each index corresponds to allFeatures.
*/
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
private fun getSupportedFeatures(tier: ArmCpuTier?): List<Boolean>? =
when (tier) {
LLamaTier.NONE, null -> null // No tier detected
LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
ArmCpuTier.NONE, null -> null // No tier detected
ArmCpuTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
ArmCpuTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
ArmCpuTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
ArmCpuTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
ArmCpuTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
}
}

View File

@ -1,10 +1,10 @@
package com.arm.aichat
/**
* Public interface for [LLamaTier] detection information.
* Public interface for [ArmCpuTier] detection information.
*/
interface TierDetection {
fun getDetectedTier(): LLamaTier?
fun getDetectedTier(): ArmCpuTier?
fun clearCache()
}
@ -12,7 +12,7 @@ interface TierDetection {
* ARM optimization tiers supported by this library.
* Higher tiers provide better performance on supported hardware.
*/
enum class LLamaTier(val rawValue: Int, val description: String) {
enum class ArmCpuTier(val rawValue: Int, val description: String) {
NONE(0, "No valid Arm® optimization available!"),
T1(1, "ARMv8-a baseline with ASIMD"),
T2(2, "ARMv8.2-a with DotProd"),
@ -21,7 +21,7 @@ enum class LLamaTier(val rawValue: Int, val description: String) {
T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
companion object {
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
fun fromRawValue(value: Int): ArmCpuTier? = entries.find { it.rawValue == value }
val maxSupportedTier = T5
}

View File

@ -7,13 +7,13 @@ import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import com.arm.aichat.LLamaTier
import com.arm.aichat.ArmCpuTier
import com.arm.aichat.TierDetection
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
/**
* Internal [LLamaTier] detection implementation
* Internal [ArmCpuTier] detection implementation
*/
internal class TierDetectionImpl private constructor(
private val context: Context
@ -25,7 +25,7 @@ internal class TierDetectionImpl private constructor(
// CPU feature detection preferences
private const val DATASTORE_CPU_DETECTION = "cpu-detection"
private const val DATASTORE_VERSION = 1
private val Context.llamaTierDataStore: DataStore<Preferences>
private val Context.armCpuTierDataStore: DataStore<Preferences>
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
private val DETECTION_VERSION = intPreferencesKey("detection_version")
@ -49,12 +49,12 @@ internal class TierDetectionImpl private constructor(
private external fun getCpuFeaturesString(): String
private var _detectedTier: LLamaTier? = null
private var _detectedTier: ArmCpuTier? = null
/**
* Get the detected tier, loading from cache if needed
*/
override fun getDetectedTier(): LLamaTier? =
override fun getDetectedTier(): ArmCpuTier? =
_detectedTier ?: runBlocking { obtainTier() }
/**
@ -73,13 +73,13 @@ internal class TierDetectionImpl private constructor(
/**
* Load cached tier from datastore without performing detection
*/
private suspend fun loadDetectedTierFromDataStore(): LLamaTier? {
val preferences = context.llamaTierDataStore.data.first()
private suspend fun loadDetectedTierFromDataStore(): ArmCpuTier? {
val preferences = context.armCpuTierDataStore.data.first()
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
LLamaTier.fromRawValue(cachedTierValue)?.also {
ArmCpuTier.fromRawValue(cachedTierValue)?.also {
Log.i(TAG, "Loaded cached tier: ${it.name}")
_detectedTier = it
}
@ -92,7 +92,7 @@ internal class TierDetectionImpl private constructor(
/**
* Actual implementation of optimal tier detection via native methods
*/
private fun performOptimalTierDetection(): LLamaTier? {
private fun performOptimalTierDetection(): ArmCpuTier? {
try {
// Load CPU detection library
System.loadLibrary("cpu-detector")
@ -104,13 +104,13 @@ internal class TierDetectionImpl private constructor(
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
// Convert to enum and validate
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
val tier = ArmCpuTier.fromRawValue(tierValue) ?: run {
Log.e(TAG, "Invalid tier value $tierValue")
return LLamaTier.NONE
return ArmCpuTier.NONE
}
// Ensure we don't exceed maximum supported tier
val maxTier = LLamaTier.maxSupportedTier
val maxTier = ArmCpuTier.maxSupportedTier
return if (tier.rawValue > maxTier.rawValue) {
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
maxTier
@ -132,13 +132,13 @@ internal class TierDetectionImpl private constructor(
* Clear cached detection results (for testing/debugging)
*/
override fun clearCache() {
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
runBlocking { context.armCpuTierDataStore.edit { it.clear() } }
_detectedTier = null
Log.i(TAG, "Cleared CPU detection results")
}
private suspend fun LLamaTier.saveToDataStore() {
context.llamaTierDataStore.edit { prefs ->
private suspend fun ArmCpuTier.saveToDataStore() {
context.armCpuTierDataStore.edit { prefs ->
prefs[DETECTED_TIER] = this.rawValue
prefs[DETECTION_VERSION] = DATASTORE_VERSION
}