From 266fc314ef7ae6990769f5be77c3caf273bd7db5 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Sun, 12 Oct 2025 15:59:51 -0700 Subject: [PATCH] lib: change `LlamaTier` to `ArmCpuTier` --- .../aiplayground/engine/StubTierDetection.kt | 4 +-- .../viewmodel/SettingsViewModel.kt | 4 +-- .../main/java/com/arm/aichat/ArmFeatures.kt | 20 ++++++------- .../main/java/com/arm/aichat/TierDetection.kt | 8 ++--- .../arm/aichat/internal/TierDetectionImpl.kt | 30 +++++++++---------- 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/arm/aiplayground/engine/StubTierDetection.kt b/examples/llama.android/app/src/main/java/com/arm/aiplayground/engine/StubTierDetection.kt index ca646f41c6..4d9955a531 100644 --- a/examples/llama.android/app/src/main/java/com/arm/aiplayground/engine/StubTierDetection.kt +++ b/examples/llama.android/app/src/main/java/com/arm/aiplayground/engine/StubTierDetection.kt @@ -1,6 +1,6 @@ package com.arm.aiplayground.engine -import com.arm.aichat.LLamaTier +import com.arm.aichat.ArmCpuTier import com.arm.aichat.TierDetection import android.util.Log @@ -10,7 +10,7 @@ import android.util.Log object StubTierDetection : TierDetection { private val tag = StubTierDetection::class.java.simpleName - override fun getDetectedTier(): LLamaTier? = LLamaTier.T3 + override fun getDetectedTier(): ArmCpuTier? = ArmCpuTier.T3 override fun clearCache() { Log.d(tag, "Cache cleared") diff --git a/examples/llama.android/app/src/main/java/com/arm/aiplayground/viewmodel/SettingsViewModel.kt b/examples/llama.android/app/src/main/java/com/arm/aiplayground/viewmodel/SettingsViewModel.kt index 35ebfd8341..c5d2c45c6e 100644 --- a/examples/llama.android/app/src/main/java/com/arm/aiplayground/viewmodel/SettingsViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/arm/aiplayground/viewmodel/SettingsViewModel.kt @@ -2,7 +2,7 @@ package com.arm.aiplayground.viewmodel import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope -import com.arm.aichat.LLamaTier +import com.arm.aichat.ArmCpuTier import com.arm.aichat.TierDetection import com.arm.aiplayground.data.repo.ModelRepository import com.arm.aiplayground.data.source.prefs.ColorThemeMode @@ -66,7 +66,7 @@ class SettingsViewModel @Inject constructor( private val _darkThemeMode = MutableStateFlow(DarkThemeMode.AUTO) val darkThemeMode: StateFlow = _darkThemeMode.asStateFlow() - val detectedTier: LLamaTier? + val detectedTier: ArmCpuTier? get() = tierDetection.getDetectedTier() init { diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt index 6ddca0d70b..cd7b4b4d31 100644 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt +++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt @@ -11,7 +11,7 @@ data class ArmFeature( ) /** - * Helper class to map LLamaTier to supported Arm® features. + * Helper class to map [ArmCpuTier] to supported Arm® features. */ object ArmFeaturesMapper { @@ -62,7 +62,7 @@ object ArmFeaturesMapper { /** * Gets the feature support data for UI display. */ - fun getFeatureDisplayData(tier: LLamaTier?): List? = + fun getFeatureDisplayData(tier: ArmCpuTier?): List? = getSupportedFeatures(tier).let { optFlags -> optFlags?.let { flags -> allFeatures.mapIndexed { index, feature -> @@ -75,16 +75,16 @@ object ArmFeaturesMapper { } /** - * Maps a LLamaTier to its supported Arm® features. + * Maps a [ArmCpuTier] to its supported Arm® features. * Returns a list of booleans where each index corresponds to allFeatures. */ - private fun getSupportedFeatures(tier: LLamaTier?): List? = + private fun getSupportedFeatures(tier: ArmCpuTier?): List? = when (tier) { - LLamaTier.NONE, null -> null // No tier detected - LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only - LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD - LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM - LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 - LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 + ArmCpuTier.NONE, null -> null // No tier detected + ArmCpuTier.T1 -> listOf(true, false, false, false, false) // ASIMD only + ArmCpuTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD + ArmCpuTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM + ArmCpuTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 + ArmCpuTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 } } diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt index 3799346359..71908bfcf4 100644 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt +++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt @@ -1,10 +1,10 @@ package com.arm.aichat /** - * Public interface for [LLamaTier] detection information. + * Public interface for [ArmCpuTier] detection information. */ interface TierDetection { - fun getDetectedTier(): LLamaTier? + fun getDetectedTier(): ArmCpuTier? fun clearCache() } @@ -12,7 +12,7 @@ interface TierDetection { * ARM optimization tiers supported by this library. * Higher tiers provide better performance on supported hardware. */ -enum class LLamaTier(val rawValue: Int, val description: String) { +enum class ArmCpuTier(val rawValue: Int, val description: String) { NONE(0, "No valid Arm® optimization available!"), T1(1, "ARMv8-a baseline with ASIMD"), T2(2, "ARMv8.2-a with DotProd"), @@ -21,7 +21,7 @@ enum class LLamaTier(val rawValue: Int, val description: String) { T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); companion object { - fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value } + fun fromRawValue(value: Int): ArmCpuTier? = entries.find { it.rawValue == value } val maxSupportedTier = T5 } diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt index 70c8d00356..ddd46c5707 100644 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt +++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt @@ -7,13 +7,13 @@ import androidx.datastore.preferences.core.Preferences import androidx.datastore.preferences.core.edit import androidx.datastore.preferences.core.intPreferencesKey import androidx.datastore.preferences.preferencesDataStore -import com.arm.aichat.LLamaTier +import com.arm.aichat.ArmCpuTier import com.arm.aichat.TierDetection import kotlinx.coroutines.flow.first import kotlinx.coroutines.runBlocking /** - * Internal [LLamaTier] detection implementation + * Internal [ArmCpuTier] detection implementation */ internal class TierDetectionImpl private constructor( private val context: Context @@ -25,7 +25,7 @@ internal class TierDetectionImpl private constructor( // CPU feature detection preferences private const val DATASTORE_CPU_DETECTION = "cpu-detection" private const val DATASTORE_VERSION = 1 - private val Context.llamaTierDataStore: DataStore + private val Context.armCpuTierDataStore: DataStore by preferencesDataStore(name = DATASTORE_CPU_DETECTION) private val DETECTION_VERSION = intPreferencesKey("detection_version") @@ -49,12 +49,12 @@ internal class TierDetectionImpl private constructor( private external fun getCpuFeaturesString(): String - private var _detectedTier: LLamaTier? = null + private var _detectedTier: ArmCpuTier? = null /** * Get the detected tier, loading from cache if needed */ - override fun getDetectedTier(): LLamaTier? = + override fun getDetectedTier(): ArmCpuTier? = _detectedTier ?: runBlocking { obtainTier() } /** @@ -73,13 +73,13 @@ internal class TierDetectionImpl private constructor( /** * Load cached tier from datastore without performing detection */ - private suspend fun loadDetectedTierFromDataStore(): LLamaTier? { - val preferences = context.llamaTierDataStore.data.first() + private suspend fun loadDetectedTierFromDataStore(): ArmCpuTier? { + val preferences = context.armCpuTierDataStore.data.first() val cachedVersion = preferences[DETECTION_VERSION] ?: -1 val cachedTierValue = preferences[DETECTED_TIER] ?: -1 return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { - LLamaTier.fromRawValue(cachedTierValue)?.also { + ArmCpuTier.fromRawValue(cachedTierValue)?.also { Log.i(TAG, "Loaded cached tier: ${it.name}") _detectedTier = it } @@ -92,7 +92,7 @@ internal class TierDetectionImpl private constructor( /** * Actual implementation of optimal tier detection via native methods */ - private fun performOptimalTierDetection(): LLamaTier? { + private fun performOptimalTierDetection(): ArmCpuTier? { try { // Load CPU detection library System.loadLibrary("cpu-detector") @@ -104,13 +104,13 @@ internal class TierDetectionImpl private constructor( Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") // Convert to enum and validate - val tier = LLamaTier.fromRawValue(tierValue) ?: run { + val tier = ArmCpuTier.fromRawValue(tierValue) ?: run { Log.e(TAG, "Invalid tier value $tierValue") - return LLamaTier.NONE + return ArmCpuTier.NONE } // Ensure we don't exceed maximum supported tier - val maxTier = LLamaTier.maxSupportedTier + val maxTier = ArmCpuTier.maxSupportedTier return if (tier.rawValue > maxTier.rawValue) { Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}") maxTier @@ -132,13 +132,13 @@ internal class TierDetectionImpl private constructor( * Clear cached detection results (for testing/debugging) */ override fun clearCache() { - runBlocking { context.llamaTierDataStore.edit { it.clear() } } + runBlocking { context.armCpuTierDataStore.edit { it.clear() } } _detectedTier = null Log.i(TAG, "Cleared CPU detection results") } - private suspend fun LLamaTier.saveToDataStore() { - context.llamaTierDataStore.edit { prefs -> + private suspend fun ArmCpuTier.saveToDataStore() { + context.armCpuTierDataStore.edit { prefs -> prefs[DETECTED_TIER] = this.rawValue prefs[DETECTION_VERSION] = DATASTORE_VERSION }