From 63e5bd07712fe32fb575bab079ae63a9749bcac6 Mon Sep 17 00:00:00 2001 From: Han Yin Date: Wed, 17 Sep 2025 13:38:04 -0700 Subject: [PATCH] lib: support x86-64 by dynamically set Arm related definitions --- .../example/llama/engine/StubTierDetection.kt | 2 +- examples/llama.android/llama/build.gradle.kts | 5 +- .../llama/src/main/cpp/CMakeLists.txt | 21 ++++++++ .../llama/src/main/cpp/cpu_detector.cpp | 50 ++++++++++++------- .../java/android/llama/cpp/ArmFeatures.kt | 10 ++-- .../java/android/llama/cpp/TierDetection.kt | 16 +++--- 6 files changed, 67 insertions(+), 37 deletions(-) diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt index f40f8a65af..c27050027b 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt @@ -10,7 +10,7 @@ import android.util.Log object StubTierDetection : TierDetection { private val tag = StubTierDetection::class.java.simpleName - override fun getDetectedTier(): LLamaTier? = LLamaTier.T2 + override fun getDetectedTier(): LLamaTier? = LLamaTier.T3 override fun clearCache() { Log.d(tag, "Cache cleared") diff --git a/examples/llama.android/llama/build.gradle.kts b/examples/llama.android/llama/build.gradle.kts index d24b73ff68..f8dcddbe25 100644 --- a/examples/llama.android/llama/build.gradle.kts +++ b/examples/llama.android/llama/build.gradle.kts @@ -17,7 +17,7 @@ android { consumerProguardFiles("consumer-rules.pro") ndk { - abiFilters += listOf("arm64-v8a") + abiFilters += listOf("arm64-v8a", "x86_64") } externalNativeBuild { cmake { @@ -29,12 +29,9 @@ android { arguments += "-DLLAMA_BUILD_COMMON=ON" arguments += "-DLLAMA_CURL=OFF" - arguments += "-DGGML_SYSTEM_ARCH=ARM" // Undocumented before 3.21 arguments += "-DGGML_NATIVE=OFF" arguments += "-DGGML_BACKEND_DL=ON" arguments += "-DGGML_CPU_ALL_VARIANTS=ON" - arguments += "-DGGML_CPU_KLEIDIAI=ON" - arguments += "-DGGML_OPENMP=ON" arguments += "-DGGML_LLAMAFILE=OFF" } } diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt index 1df1b31552..d1ac9274e5 100644 --- a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +++ b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt @@ -27,12 +27,33 @@ target_link_libraries(kleidi-llama-cpu-detector # 2. Kleidi Llama library # -------------------------------------------------------------------------- +if(DEFINED ANDROID_ABI) + message(STATUS "Detected Android ABI: ${ANDROID_ABI}") + if(ANDROID_ABI STREQUAL "arm64-v8a") + set(GGML_SYSTEM_ARCH "ARM") + set(GGML_CPU_KLEIDIAI ON) + set(GGML_OPENMP ON) + elseif(ANDROID_ABI STREQUAL "x86_64") + set(GGML_SYSTEM_ARCH "x86") + set(GGML_CPU_KLEIDIAI OFF) + set(GGML_OPENMP OFF) + else() + message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}") + endif() +endif() + set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../) add_subdirectory(${LLAMA_SRC} build-llama) add_library(${CMAKE_PROJECT_NAME} SHARED kleidi-llama.cpp) +target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE + GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH} + GGML_CPU_KLEIDIAI=$ + GGML_OPENMP=$ +) + target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE ${LLAMA_SRC} ${LLAMA_SRC}/common diff --git a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp index 9f3a1ce9d6..85826754a1 100644 --- a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp +++ b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp @@ -1,45 +1,53 @@ #include -#include #include #include +#if defined(__aarch64__) +#include using namespace cpu_features; +static const Aarch64Info info = GetAarch64Info(); +static const Aarch64Features features = info.features; +#endif #define LOG_TAG "CpuDetector" #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) -static const Aarch64Info info = GetAarch64Info(); -static const Aarch64Features features = info.features; - extern "C" JNIEXPORT jint JNICALL Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier( - JNIEnv* env, - jclass clazz) { + JNIEnv* /*env*/, + jobject /*clazz*/) { int tier = 0; // Default to T0 (baseline) +#if defined(__aarch64__) // Check features in reverse order (highest tier first) - // TODO-han.yin: implement T4 once obtaining an Android device with SME! - if (features.sve && features.sve2) { - tier = 3; // T3: ARMv9-a with SVE/SVE2 - LOGI("Detected SVE/SVE2 support - selecting T3"); + if (features.sme) { + tier = 5; // T5: ARMv9-a with SVE/SVE2 + LOGI("Detected SME support - selecting T5"); + } + else if (features.sve && features.sve2) { + tier = 4; // T4: ARMv9-a with SVE/SVE2 + LOGI("Detected SVE/SVE2 support - selecting T4"); } else if (features.i8mm) { - tier = 2; // T2: ARMv8.6-a with i8mm - LOGI("Detected i8mm support - selecting T2"); + tier = 3; // T3: ARMv8.6-a with i8mm + LOGI("Detected i8mm support - selecting T3"); } else if (features.asimddp) { - tier = 1; // T1: ARMv8.2-a with dotprod - LOGI("Detected dotprod support - selecting T1"); + tier = 2; // T2: ARMv8.2-a with dotprod + LOGI("Detected dotprod support - selecting T2"); } else if (features.asimd) { - tier = 0; // T0: baseline ARMv8-a with ASIMD - LOGI("Detected basic ASIMD support - selecting T0"); + tier = 1; // T1: baseline ARMv8-a with ASIMD + LOGI("Detected basic ASIMD support - selecting T1"); } else { // Fallback - this shouldn't happen on arm64-v8a devices - tier = 0; - LOGI("No expected features detected - falling back to T0"); + tier = 1; + LOGI("No expected features detected - falling back to T1"); } +#else + LOGI("non aarch64 architecture detected - defaulting to T0"); +#endif return tier; } @@ -48,15 +56,19 @@ Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier( extern "C" JNIEXPORT jstring JNICALL Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString( JNIEnv* env, - jclass clazz) { + jobject /*clazz*/) { std::string text; +#if defined(__aarch64__) if (features.asimd) text += "ASIMD "; if (features.asimddp) text += "ASIMDDP "; if (features.i8mm) text += "I8MM "; if (features.sve) text += "SVE "; if (features.sve2) text += "SVE2 "; if (features.sme) text += "SME "; +#else + LOGI("non aarch64 architecture detected"); +#endif return env->NewStringUTF(text.c_str()); } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt index 1ca4949450..d3b1f668b0 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt @@ -81,10 +81,10 @@ object ArmFeaturesMapper { private fun getSupportedFeatures(tier: LLamaTier?): List? = when (tier) { LLamaTier.NONE, null -> null // No tier detected - LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only - LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD - LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM - LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 - LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 + LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only + LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD + LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM + LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 + LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt index 5728c1c45b..26224359b5 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt @@ -12,17 +12,17 @@ interface TierDetection { * ARM optimization tiers supported by the Kleidi-Llama library. * Higher tiers provide better performance on supported hardware. */ -enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { - NONE(-404, "", "No valid ArmĀ® optimization available!"), - T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"), - T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), - T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), - T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"), - T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); +enum class LLamaTier(val rawValue: Int, val description: String) { + NONE(0, "No valid ArmĀ® optimization available!"), + T1(1, "ARMv8-a baseline with ASIMD"), + T2(2, "ARMv8.2-a with DotProd"), + T3(3, "ARMv8.6-a with DotProd + I8MM"), + T4(4, "ARMv9-a with DotProd + I8MM + SVE/SVE2"), + T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); companion object { fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value } - val maxSupportedTier = T3 + val maxSupportedTier = T5 } }