diff --git a/.gitmodules b/.gitmodules index b3c259fa9e..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "include/cpu_features"] - path = include/cpu_features - url = https://github.com/google/cpu_features diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt index 4923e8e764..52c5dc2154 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -14,7 +14,6 @@ import androidx.recyclerview.widget.LinearLayoutManager import androidx.recyclerview.widget.RecyclerView import com.arm.aichat.AiChat import com.arm.aichat.InferenceEngine -import com.arm.aichat.TierDetection import com.arm.aichat.gguf.GgufMetadata import com.arm.aichat.gguf.GgufMetadataReader import com.google.android.material.floatingactionbutton.FloatingActionButton @@ -30,18 +29,16 @@ import java.util.UUID class MainActivity : AppCompatActivity() { // Android views - private lateinit var tierTv: TextView - private lateinit var pickerBtn: FloatingActionButton private lateinit var ggufTv: TextView private lateinit var messagesRv: RecyclerView private lateinit var userInputEt: EditText - private lateinit var userSendBtn: FloatingActionButton + private lateinit var userActionFab: FloatingActionButton - // Arm AI Chat engine and utils - private lateinit var detection: TierDetection + // Arm AI Chat inference engine private lateinit var engine: InferenceEngine // Conversation states + private var isModelReady = false private val messages = mutableListOf() private val lastAssistantMsg = StringBuilder() private val messageAdapter = MessageAdapter(messages) @@ -52,35 +49,27 @@ class MainActivity : AppCompatActivity() { setContentView(R.layout.activity_main) // Find views - tierTv = findViewById(R.id.tier) - pickerBtn = findViewById(R.id.pick_model) ggufTv = findViewById(R.id.gguf) messagesRv = findViewById(R.id.messages) messagesRv.layoutManager = LinearLayoutManager(this) messagesRv.adapter = messageAdapter userInputEt = findViewById(R.id.user_input) - userSendBtn = findViewById(R.id.user_send) + userActionFab = findViewById(R.id.fab) // Arm AI Chat initialization lifecycleScope.launch(Dispatchers.Default) { - // Obtain the device's CPU feature tier - detection = AiChat.getTierDetection(applicationContext) - withContext(Dispatchers.Main) { - tierTv.text = detection.getDetectedTier()?.description ?: "N/A" - } - - // Obtain the inference engine engine = AiChat.getInferenceEngine(applicationContext) } - // Upon file picker button tapped, prompt user to select a GGUF metadata on the device - pickerBtn.setOnClickListener { - getContent.launch(arrayOf("*/*")) - } - - // Upon user send button tapped, validate input and send to engine - userSendBtn.setOnClickListener { - handleUserInput() + // Upon CTA button tapped + userActionFab.setOnClickListener { + if (isModelReady) { + // If model is ready, validate input and send to engine + handleUserInput() + } else { + // Otherwise, prompt user to select a GGUF metadata on the device + getContent.launch(arrayOf("*/*")) + } } } @@ -96,7 +85,7 @@ class MainActivity : AppCompatActivity() { */ private fun handleSelectedModel(uri: Uri) { // Update UI states - pickerBtn.isEnabled = false + userActionFab.isEnabled = false userInputEt.hint = "Parsing GGUF..." ggufTv.text = "Parsing metadata from selected file \n$uri" @@ -120,9 +109,11 @@ class MainActivity : AppCompatActivity() { loadModel(modelName, modelFile) withContext(Dispatchers.Main) { + isModelReady = true userInputEt.hint = "Type and send a message!" userInputEt.isEnabled = true - userSendBtn.isEnabled = true + userActionFab.setImageResource(R.drawable.outline_send_24) + userActionFab.isEnabled = true } } } @@ -171,7 +162,7 @@ class MainActivity : AppCompatActivity() { Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show() } else { userInputEt.text = null - userSendBtn.isEnabled = false + userActionFab.isEnabled = false // Update message states messages.add(Message(UUID.randomUUID().toString(), userSsg, true)) @@ -182,7 +173,7 @@ class MainActivity : AppCompatActivity() { engine.sendUserPrompt(userSsg) .onCompletion { withContext(Dispatchers.Main) { - userSendBtn.isEnabled = true + userActionFab.isEnabled = true } }.collect { token -> val messageCount = messages.size diff --git a/examples/llama.android/app/src/main/res/layout/activity_main.xml b/examples/llama.android/app/src/main/res/layout/activity_main.xml index 90eda033e7..bf6ef35925 100644 --- a/examples/llama.android/app/src/main/res/layout/activity_main.xml +++ b/examples/llama.android/app/src/main/res/layout/activity_main.xml @@ -13,48 +13,28 @@ android:orientation="vertical" tools:context=".MainActivity"> - - - - - - - - - + android:layout_weight="1"> - + android:layout_height="wrap_content" + android:fadeScrollbars="false"> - + + + + + + android:src="@drawable/outline_folder_open_24" /> - \ No newline at end of file + diff --git a/examples/llama.android/lib/build.gradle.kts b/examples/llama.android/lib/build.gradle.kts index 263fee6068..5255f0c17b 100644 --- a/examples/llama.android/lib/build.gradle.kts +++ b/examples/llama.android/lib/build.gradle.kts @@ -1,7 +1,6 @@ plugins { alias(libs.plugins.android.library) alias(libs.plugins.jetbrains.kotlin.android) - `maven-publish` } android { @@ -70,31 +69,6 @@ android { } } -publishing { - publications { - register("release") { - groupId = "com.arm" - artifactId = "ai-chat" - version = "0.1.0" - - afterEvaluate { - from(components["release"]) - } - } - } - - repositories { - maven { - name = "artifactory" - url = uri(project.findProperty("artifactoryUrl") as? String ?: "") - credentials { - username = project.findProperty("artifactoryUsername") as? String ?: "" - password = project.findProperty("artifactoryPassword") as? String ?: "" - } - } - } -} - dependencies { implementation(libs.androidx.core.ktx) implementation(libs.androidx.datastore.preferences) diff --git a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt index 06d3d03aa5..7862c61a3f 100644 --- a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt +++ b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt @@ -12,19 +12,7 @@ set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE) # -------------------------------------------------------------------------- -# 1. CPU feature detection library -# -------------------------------------------------------------------------- -add_subdirectory( - ${CMAKE_CURRENT_LIST_DIR}/../../../../../../include/cpu_features - ${CMAKE_BINARY_DIR}/cpu_features_build) -add_library(cpu-detector SHARED cpu_detector.cpp) -target_link_libraries(cpu-detector - PRIVATE CpuFeatures::cpu_features - android - log) - -# -------------------------------------------------------------------------- -# 2. AI Chat library +# AI Chat library # -------------------------------------------------------------------------- if(DEFINED ANDROID_ABI) diff --git a/examples/llama.android/lib/src/main/cpp/cpu_detector.cpp b/examples/llama.android/lib/src/main/cpp/cpu_detector.cpp deleted file mode 100644 index 05bc238056..0000000000 --- a/examples/llama.android/lib/src/main/cpp/cpu_detector.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include -#include -#include - -#if defined(__aarch64__) -#include -using namespace cpu_features; -static const Aarch64Info info = GetAarch64Info(); -static const Aarch64Features features = info.features; -#endif - -#define LOG_TAG "CpuDetector" -#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) - -extern "C" JNIEXPORT jint JNICALL -Java_com_arm_aichat_internal_TierDetectionImpl_getOptimalTier( - JNIEnv* /*env*/, - jobject /*clazz*/) { - int tier = 0; // Default to T0 (baseline) - -#if defined(__aarch64__) - // Check features in reverse order (highest tier first) - if (features.sme) { - tier = 5; // T5: ARMv9-a with SVE/SVE2 - LOGI("Detected SME support - selecting T5"); - } - else if (features.sve && features.sve2) { - tier = 4; // T4: ARMv9-a with SVE/SVE2 - LOGI("Detected SVE/SVE2 support - selecting T4"); - } - else if (features.i8mm) { - tier = 3; // T3: ARMv8.6-a with i8mm - LOGI("Detected i8mm support - selecting T3"); - } - else if (features.asimddp) { - tier = 2; // T2: ARMv8.2-a with dotprod - LOGI("Detected dotprod support - selecting T2"); - } - else if (features.asimd) { - tier = 1; // T1: baseline ARMv8-a with ASIMD - LOGI("Detected basic ASIMD support - selecting T1"); - } - else { - // Fallback - this shouldn't happen on arm64-v8a devices - tier = 1; - LOGI("No expected features detected - falling back to T1"); - } -#else - LOGI("non aarch64 architecture detected - defaulting to T0"); -#endif - - return tier; -} - -// Optional: Keep a feature string function for debugging -extern "C" JNIEXPORT jstring JNICALL -Java_com_arm_aichat_internal_TierDetectionImpl_getCpuFeaturesString( - JNIEnv* env, - jobject /*clazz*/) { - std::string text; - -#if defined(__aarch64__) - if (features.asimd) text += "ASIMD "; - if (features.asimddp) text += "ASIMDDP "; - if (features.i8mm) text += "I8MM "; - if (features.sve) text += "SVE "; - if (features.sve2) text += "SVE2 "; - if (features.sme) text += "SME "; -#else - LOGI("non aarch64 architecture detected"); -#endif - - return env->NewStringUTF(text.c_str()); -} diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt index 151b2fbfa2..b72a24ec1d 100644 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt +++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt @@ -2,7 +2,6 @@ package com.arm.aichat import android.content.Context import com.arm.aichat.internal.InferenceEngineImpl -import com.arm.aichat.internal.TierDetectionImpl /** * Main entry point for Arm's AI Chat library. @@ -12,9 +11,4 @@ object AiChat { * Get the inference engine single instance. */ fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context) - - /** - * Get tier detection single instance. - */ - fun getTierDetection(context: Context): TierDetection = TierDetectionImpl.getInstance(context) } diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt deleted file mode 100644 index cd7b4b4d31..0000000000 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/ArmFeatures.kt +++ /dev/null @@ -1,90 +0,0 @@ -package com.arm.aichat - -/** - * Represents an Arm® CPU feature with its metadata. - */ -data class ArmFeature( - val name: String, - val displayName: String, - val description: String, - val armDocUrl: String -) - -/** - * Helper class to map [ArmCpuTier] to supported Arm® features. - */ -object ArmFeaturesMapper { - - /** - * UI display item combining feature info with support status. - */ - data class DisplayItem( - val feature: ArmFeature, - val isSupported: Boolean - ) - - /** - * All Arm® features supported by the library, in order of introduction. - */ - val allFeatures = listOf( - ArmFeature( - name = "ASIMD", - displayName = "ASIMD", - description = "Advanced SIMD (NEON) - baseline vectorization", - armDocUrl = "https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/matrix-matrix-multiplication-neon-sve-and-sme-compared" - ), - ArmFeature( - name = "DOTPROD", - displayName = "DOTPROD", - description = "Dot Product instructions for neural networks", - armDocUrl = "https://community.arm.com/arm-community-blogs/b/tools-software-ides-blog/posts/exploring-the-arm-dot-product-instructions" - ), - ArmFeature( - name = "I8MM", - displayName = "I8MM", - description = "Integer 8-bit Matrix Multiplication", - armDocUrl = "https://community.arm.com/arm-community-blogs/b/ai-blog/posts/optimize-llama-cpp-with-arm-i8mm-instruction" - ), - ArmFeature( - name = "SVE", - displayName = "SVE", - description = "Scalable Vector Extension", - armDocUrl = "https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/sve2" - ), - ArmFeature( - name = "SME", - displayName = "SME", - description = "Scalable Matrix Extension", - armDocUrl = "https://newsroom.arm.com/blog/scalable-matrix-extension" - ) - ) - - /** - * Gets the feature support data for UI display. - */ - fun getFeatureDisplayData(tier: ArmCpuTier?): List? = - getSupportedFeatures(tier).let { optFlags -> - optFlags?.let { flags -> - allFeatures.mapIndexed { index, feature -> - DisplayItem( - feature = feature, - isSupported = flags.getOrElse(index) { false } - ) - } - } - } - - /** - * Maps a [ArmCpuTier] to its supported Arm® features. - * Returns a list of booleans where each index corresponds to allFeatures. - */ - private fun getSupportedFeatures(tier: ArmCpuTier?): List? = - when (tier) { - ArmCpuTier.NONE, null -> null // No tier detected - ArmCpuTier.T1 -> listOf(true, false, false, false, false) // ASIMD only - ArmCpuTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD - ArmCpuTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM - ArmCpuTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 - ArmCpuTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 - } -} diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt deleted file mode 100644 index 71908bfcf4..0000000000 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/TierDetection.kt +++ /dev/null @@ -1,28 +0,0 @@ -package com.arm.aichat - -/** - * Public interface for [ArmCpuTier] detection information. - */ -interface TierDetection { - fun getDetectedTier(): ArmCpuTier? - fun clearCache() -} - -/** - * ARM optimization tiers supported by this library. - * Higher tiers provide better performance on supported hardware. - */ -enum class ArmCpuTier(val rawValue: Int, val description: String) { - NONE(0, "No valid Arm® optimization available!"), - T1(1, "ARMv8-a baseline with ASIMD"), - T2(2, "ARMv8.2-a with DotProd"), - T3(3, "ARMv8.6-a with DotProd + I8MM"), - T4(4, "ARMv9-a with DotProd + I8MM + SVE/SVE2"), - T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); - - companion object { - fun fromRawValue(value: Int): ArmCpuTier? = entries.find { it.rawValue == value } - - val maxSupportedTier = T5 - } -} diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt deleted file mode 100644 index ddd46c5707..0000000000 --- a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/TierDetectionImpl.kt +++ /dev/null @@ -1,147 +0,0 @@ -package com.arm.aichat.internal - -import android.content.Context -import android.util.Log -import androidx.datastore.core.DataStore -import androidx.datastore.preferences.core.Preferences -import androidx.datastore.preferences.core.edit -import androidx.datastore.preferences.core.intPreferencesKey -import androidx.datastore.preferences.preferencesDataStore -import com.arm.aichat.ArmCpuTier -import com.arm.aichat.TierDetection -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.runBlocking - -/** - * Internal [ArmCpuTier] detection implementation - */ -internal class TierDetectionImpl private constructor( - private val context: Context -): TierDetection { - - companion object { - private val TAG = TierDetectionImpl::class.simpleName - - // CPU feature detection preferences - private const val DATASTORE_CPU_DETECTION = "cpu-detection" - private const val DATASTORE_VERSION = 1 - private val Context.armCpuTierDataStore: DataStore - by preferencesDataStore(name = DATASTORE_CPU_DETECTION) - - private val DETECTION_VERSION = intPreferencesKey("detection_version") - private val DETECTED_TIER = intPreferencesKey("detected_tier") - - @Volatile - private var instance: TierDetection? = null - - /** - * Create or obtain [TierDetectionImpl]'s single instance. - * - * @param Context for obtaining the data store - */ - internal fun getInstance(context: Context) = - instance ?: synchronized(this) { - instance ?: TierDetectionImpl(context).also { instance = it } - } - } - - private external fun getOptimalTier(): Int - - private external fun getCpuFeaturesString(): String - - private var _detectedTier: ArmCpuTier? = null - - /** - * Get the detected tier, loading from cache if needed - */ - override fun getDetectedTier(): ArmCpuTier? = - _detectedTier ?: runBlocking { obtainTier() } - - /** - * First attempt to load detected tier from storage, if available; - * Otherwise, perform a fresh detection, then save to storage and cache. - */ - private suspend fun obtainTier() = - loadDetectedTierFromDataStore() ?: run { - Log.i(TAG, "Performing fresh tier detection") - performOptimalTierDetection().also { tier -> - tier?.saveToDataStore() - _detectedTier = tier - } - } - - /** - * Load cached tier from datastore without performing detection - */ - private suspend fun loadDetectedTierFromDataStore(): ArmCpuTier? { - val preferences = context.armCpuTierDataStore.data.first() - val cachedVersion = preferences[DETECTION_VERSION] ?: -1 - val cachedTierValue = preferences[DETECTED_TIER] ?: -1 - - return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { - ArmCpuTier.fromRawValue(cachedTierValue)?.also { - Log.i(TAG, "Loaded cached tier: ${it.name}") - _detectedTier = it - } - } else { - Log.i(TAG, "No valid cached tier found") - null - } - } - - /** - * Actual implementation of optimal tier detection via native methods - */ - private fun performOptimalTierDetection(): ArmCpuTier? { - try { - // Load CPU detection library - System.loadLibrary("cpu-detector") - Log.i(TAG, "CPU feature detector loaded successfully") - - // Detect optimal tier - val tierValue = getOptimalTier() - val features = getCpuFeaturesString() - Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") - - // Convert to enum and validate - val tier = ArmCpuTier.fromRawValue(tierValue) ?: run { - Log.e(TAG, "Invalid tier value $tierValue") - return ArmCpuTier.NONE - } - - // Ensure we don't exceed maximum supported tier - val maxTier = ArmCpuTier.maxSupportedTier - return if (tier.rawValue > maxTier.rawValue) { - Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}") - maxTier - } else { - tier - } - - } catch (e: UnsatisfiedLinkError) { - Log.e(TAG, "Failed to load CPU detection library", e) - return null - - } catch (e: Exception) { - Log.e(TAG, "Unexpected error during tier detection", e) - return null - } - } - - /** - * Clear cached detection results (for testing/debugging) - */ - override fun clearCache() { - runBlocking { context.armCpuTierDataStore.edit { it.clear() } } - _detectedTier = null - Log.i(TAG, "Cleared CPU detection results") - } - - private suspend fun ArmCpuTier.saveToDataStore() { - context.armCpuTierDataStore.edit { prefs -> - prefs[DETECTED_TIER] = this.rawValue - prefs[DETECTION_VERSION] = DATASTORE_VERSION - } - Log.i(TAG, "Saved ${this.name} to data store") - } -}