lib: support x86-64 by dynamically set Arm related definitions

This commit is contained in:
Han Yin 2025-09-17 13:38:04 -07:00
parent 8f90e42ee2
commit 63e5bd0771
6 changed files with 67 additions and 37 deletions

View File

@ -10,7 +10,7 @@ import android.util.Log
object StubTierDetection : TierDetection { object StubTierDetection : TierDetection {
private val tag = StubTierDetection::class.java.simpleName private val tag = StubTierDetection::class.java.simpleName
override fun getDetectedTier(): LLamaTier? = LLamaTier.T2 override fun getDetectedTier(): LLamaTier? = LLamaTier.T3
override fun clearCache() { override fun clearCache() {
Log.d(tag, "Cache cleared") Log.d(tag, "Cache cleared")

View File

@ -17,7 +17,7 @@ android {
consumerProguardFiles("consumer-rules.pro") consumerProguardFiles("consumer-rules.pro")
ndk { ndk {
abiFilters += listOf("arm64-v8a") abiFilters += listOf("arm64-v8a", "x86_64")
} }
externalNativeBuild { externalNativeBuild {
cmake { cmake {
@ -29,12 +29,9 @@ android {
arguments += "-DLLAMA_BUILD_COMMON=ON" arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DLLAMA_CURL=OFF" arguments += "-DLLAMA_CURL=OFF"
arguments += "-DGGML_SYSTEM_ARCH=ARM" // Undocumented before 3.21
arguments += "-DGGML_NATIVE=OFF" arguments += "-DGGML_NATIVE=OFF"
arguments += "-DGGML_BACKEND_DL=ON" arguments += "-DGGML_BACKEND_DL=ON"
arguments += "-DGGML_CPU_ALL_VARIANTS=ON" arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
arguments += "-DGGML_CPU_KLEIDIAI=ON"
arguments += "-DGGML_OPENMP=ON"
arguments += "-DGGML_LLAMAFILE=OFF" arguments += "-DGGML_LLAMAFILE=OFF"
} }
} }

View File

@ -27,12 +27,33 @@ target_link_libraries(kleidi-llama-cpu-detector
# 2. Kleidi Llama library # 2. Kleidi Llama library
# -------------------------------------------------------------------------- # --------------------------------------------------------------------------
if(DEFINED ANDROID_ABI)
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(GGML_SYSTEM_ARCH "ARM")
set(GGML_CPU_KLEIDIAI ON)
set(GGML_OPENMP ON)
elseif(ANDROID_ABI STREQUAL "x86_64")
set(GGML_SYSTEM_ARCH "x86")
set(GGML_CPU_KLEIDIAI OFF)
set(GGML_OPENMP OFF)
else()
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
endif()
endif()
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../) set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama) add_subdirectory(${LLAMA_SRC} build-llama)
add_library(${CMAKE_PROJECT_NAME} SHARED add_library(${CMAKE_PROJECT_NAME} SHARED
kleidi-llama.cpp) kleidi-llama.cpp)
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
${LLAMA_SRC} ${LLAMA_SRC}
${LLAMA_SRC}/common ${LLAMA_SRC}/common

View File

@ -1,45 +1,53 @@
#include <jni.h> #include <jni.h>
#include <cpuinfo_aarch64.h>
#include <android/log.h> #include <android/log.h>
#include <string> #include <string>
#if defined(__aarch64__)
#include <cpuinfo_aarch64.h>
using namespace cpu_features; using namespace cpu_features;
static const Aarch64Info info = GetAarch64Info();
static const Aarch64Features features = info.features;
#endif
#define LOG_TAG "CpuDetector" #define LOG_TAG "CpuDetector"
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__) #define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
static const Aarch64Info info = GetAarch64Info();
static const Aarch64Features features = info.features;
extern "C" JNIEXPORT jint JNICALL extern "C" JNIEXPORT jint JNICALL
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier( Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
JNIEnv* env, JNIEnv* /*env*/,
jclass clazz) { jobject /*clazz*/) {
int tier = 0; // Default to T0 (baseline) int tier = 0; // Default to T0 (baseline)
#if defined(__aarch64__)
// Check features in reverse order (highest tier first) // Check features in reverse order (highest tier first)
// TODO-han.yin: implement T4 once obtaining an Android device with SME! if (features.sme) {
if (features.sve && features.sve2) { tier = 5; // T5: ARMv9-a with SVE/SVE2
tier = 3; // T3: ARMv9-a with SVE/SVE2 LOGI("Detected SME support - selecting T5");
LOGI("Detected SVE/SVE2 support - selecting T3"); }
else if (features.sve && features.sve2) {
tier = 4; // T4: ARMv9-a with SVE/SVE2
LOGI("Detected SVE/SVE2 support - selecting T4");
} }
else if (features.i8mm) { else if (features.i8mm) {
tier = 2; // T2: ARMv8.6-a with i8mm tier = 3; // T3: ARMv8.6-a with i8mm
LOGI("Detected i8mm support - selecting T2"); LOGI("Detected i8mm support - selecting T3");
} }
else if (features.asimddp) { else if (features.asimddp) {
tier = 1; // T1: ARMv8.2-a with dotprod tier = 2; // T2: ARMv8.2-a with dotprod
LOGI("Detected dotprod support - selecting T1"); LOGI("Detected dotprod support - selecting T2");
} }
else if (features.asimd) { else if (features.asimd) {
tier = 0; // T0: baseline ARMv8-a with ASIMD tier = 1; // T1: baseline ARMv8-a with ASIMD
LOGI("Detected basic ASIMD support - selecting T0"); LOGI("Detected basic ASIMD support - selecting T1");
} }
else { else {
// Fallback - this shouldn't happen on arm64-v8a devices // Fallback - this shouldn't happen on arm64-v8a devices
tier = 0; tier = 1;
LOGI("No expected features detected - falling back to T0"); LOGI("No expected features detected - falling back to T1");
} }
#else
LOGI("non aarch64 architecture detected - defaulting to T0");
#endif
return tier; return tier;
} }
@ -48,15 +56,19 @@ Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
extern "C" JNIEXPORT jstring JNICALL extern "C" JNIEXPORT jstring JNICALL
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString( Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
JNIEnv* env, JNIEnv* env,
jclass clazz) { jobject /*clazz*/) {
std::string text; std::string text;
#if defined(__aarch64__)
if (features.asimd) text += "ASIMD "; if (features.asimd) text += "ASIMD ";
if (features.asimddp) text += "ASIMDDP "; if (features.asimddp) text += "ASIMDDP ";
if (features.i8mm) text += "I8MM "; if (features.i8mm) text += "I8MM ";
if (features.sve) text += "SVE "; if (features.sve) text += "SVE ";
if (features.sve2) text += "SVE2 "; if (features.sve2) text += "SVE2 ";
if (features.sme) text += "SME "; if (features.sme) text += "SME ";
#else
LOGI("non aarch64 architecture detected");
#endif
return env->NewStringUTF(text.c_str()); return env->NewStringUTF(text.c_str());
} }

View File

@ -81,10 +81,10 @@ object ArmFeaturesMapper {
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? = private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
when (tier) { when (tier) {
LLamaTier.NONE, null -> null // No tier detected LLamaTier.NONE, null -> null // No tier detected
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
} }
} }

View File

@ -12,17 +12,17 @@ interface TierDetection {
* ARM optimization tiers supported by the Kleidi-Llama library. * ARM optimization tiers supported by the Kleidi-Llama library.
* Higher tiers provide better performance on supported hardware. * Higher tiers provide better performance on supported hardware.
*/ */
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { enum class LLamaTier(val rawValue: Int, val description: String) {
NONE(-404, "", "No valid Arm® optimization available!"), NONE(0, "No valid Arm® optimization available!"),
T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"), T1(1, "ARMv8-a baseline with ASIMD"),
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), T2(2, "ARMv8.2-a with DotProd"),
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), T3(3, "ARMv8.6-a with DotProd + I8MM"),
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"), T4(4, "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
companion object { companion object {
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value } fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
val maxSupportedTier = T3 val maxSupportedTier = T5
} }
} }