lib: support x86-64 by dynamically set Arm related definitions
This commit is contained in:
parent
8f90e42ee2
commit
63e5bd0771
|
|
@ -10,7 +10,7 @@ import android.util.Log
|
||||||
object StubTierDetection : TierDetection {
|
object StubTierDetection : TierDetection {
|
||||||
private val tag = StubTierDetection::class.java.simpleName
|
private val tag = StubTierDetection::class.java.simpleName
|
||||||
|
|
||||||
override fun getDetectedTier(): LLamaTier? = LLamaTier.T2
|
override fun getDetectedTier(): LLamaTier? = LLamaTier.T3
|
||||||
|
|
||||||
override fun clearCache() {
|
override fun clearCache() {
|
||||||
Log.d(tag, "Cache cleared")
|
Log.d(tag, "Cache cleared")
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ android {
|
||||||
consumerProguardFiles("consumer-rules.pro")
|
consumerProguardFiles("consumer-rules.pro")
|
||||||
|
|
||||||
ndk {
|
ndk {
|
||||||
abiFilters += listOf("arm64-v8a")
|
abiFilters += listOf("arm64-v8a", "x86_64")
|
||||||
}
|
}
|
||||||
externalNativeBuild {
|
externalNativeBuild {
|
||||||
cmake {
|
cmake {
|
||||||
|
|
@ -29,12 +29,9 @@ android {
|
||||||
arguments += "-DLLAMA_BUILD_COMMON=ON"
|
arguments += "-DLLAMA_BUILD_COMMON=ON"
|
||||||
arguments += "-DLLAMA_CURL=OFF"
|
arguments += "-DLLAMA_CURL=OFF"
|
||||||
|
|
||||||
arguments += "-DGGML_SYSTEM_ARCH=ARM" // Undocumented before 3.21
|
|
||||||
arguments += "-DGGML_NATIVE=OFF"
|
arguments += "-DGGML_NATIVE=OFF"
|
||||||
arguments += "-DGGML_BACKEND_DL=ON"
|
arguments += "-DGGML_BACKEND_DL=ON"
|
||||||
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
|
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
|
||||||
arguments += "-DGGML_CPU_KLEIDIAI=ON"
|
|
||||||
arguments += "-DGGML_OPENMP=ON"
|
|
||||||
arguments += "-DGGML_LLAMAFILE=OFF"
|
arguments += "-DGGML_LLAMAFILE=OFF"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -27,12 +27,33 @@ target_link_libraries(kleidi-llama-cpu-detector
|
||||||
# 2. Kleidi Llama library
|
# 2. Kleidi Llama library
|
||||||
# --------------------------------------------------------------------------
|
# --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
if(DEFINED ANDROID_ABI)
|
||||||
|
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
|
||||||
|
if(ANDROID_ABI STREQUAL "arm64-v8a")
|
||||||
|
set(GGML_SYSTEM_ARCH "ARM")
|
||||||
|
set(GGML_CPU_KLEIDIAI ON)
|
||||||
|
set(GGML_OPENMP ON)
|
||||||
|
elseif(ANDROID_ABI STREQUAL "x86_64")
|
||||||
|
set(GGML_SYSTEM_ARCH "x86")
|
||||||
|
set(GGML_CPU_KLEIDIAI OFF)
|
||||||
|
set(GGML_OPENMP OFF)
|
||||||
|
else()
|
||||||
|
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
|
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
|
||||||
add_subdirectory(${LLAMA_SRC} build-llama)
|
add_subdirectory(${LLAMA_SRC} build-llama)
|
||||||
|
|
||||||
add_library(${CMAKE_PROJECT_NAME} SHARED
|
add_library(${CMAKE_PROJECT_NAME} SHARED
|
||||||
kleidi-llama.cpp)
|
kleidi-llama.cpp)
|
||||||
|
|
||||||
|
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
|
||||||
|
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
|
||||||
|
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
|
||||||
|
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
|
||||||
|
)
|
||||||
|
|
||||||
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
|
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
|
||||||
${LLAMA_SRC}
|
${LLAMA_SRC}
|
||||||
${LLAMA_SRC}/common
|
${LLAMA_SRC}/common
|
||||||
|
|
|
||||||
|
|
@ -1,45 +1,53 @@
|
||||||
#include <jni.h>
|
#include <jni.h>
|
||||||
#include <cpuinfo_aarch64.h>
|
|
||||||
#include <android/log.h>
|
#include <android/log.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
|
#include <cpuinfo_aarch64.h>
|
||||||
using namespace cpu_features;
|
using namespace cpu_features;
|
||||||
|
static const Aarch64Info info = GetAarch64Info();
|
||||||
|
static const Aarch64Features features = info.features;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define LOG_TAG "CpuDetector"
|
#define LOG_TAG "CpuDetector"
|
||||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||||
|
|
||||||
static const Aarch64Info info = GetAarch64Info();
|
|
||||||
static const Aarch64Features features = info.features;
|
|
||||||
|
|
||||||
extern "C" JNIEXPORT jint JNICALL
|
extern "C" JNIEXPORT jint JNICALL
|
||||||
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
||||||
JNIEnv* env,
|
JNIEnv* /*env*/,
|
||||||
jclass clazz) {
|
jobject /*clazz*/) {
|
||||||
int tier = 0; // Default to T0 (baseline)
|
int tier = 0; // Default to T0 (baseline)
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
// Check features in reverse order (highest tier first)
|
// Check features in reverse order (highest tier first)
|
||||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
if (features.sme) {
|
||||||
if (features.sve && features.sve2) {
|
tier = 5; // T5: ARMv9-a with SVE/SVE2
|
||||||
tier = 3; // T3: ARMv9-a with SVE/SVE2
|
LOGI("Detected SME support - selecting T5");
|
||||||
LOGI("Detected SVE/SVE2 support - selecting T3");
|
}
|
||||||
|
else if (features.sve && features.sve2) {
|
||||||
|
tier = 4; // T4: ARMv9-a with SVE/SVE2
|
||||||
|
LOGI("Detected SVE/SVE2 support - selecting T4");
|
||||||
}
|
}
|
||||||
else if (features.i8mm) {
|
else if (features.i8mm) {
|
||||||
tier = 2; // T2: ARMv8.6-a with i8mm
|
tier = 3; // T3: ARMv8.6-a with i8mm
|
||||||
LOGI("Detected i8mm support - selecting T2");
|
LOGI("Detected i8mm support - selecting T3");
|
||||||
}
|
}
|
||||||
else if (features.asimddp) {
|
else if (features.asimddp) {
|
||||||
tier = 1; // T1: ARMv8.2-a with dotprod
|
tier = 2; // T2: ARMv8.2-a with dotprod
|
||||||
LOGI("Detected dotprod support - selecting T1");
|
LOGI("Detected dotprod support - selecting T2");
|
||||||
}
|
}
|
||||||
else if (features.asimd) {
|
else if (features.asimd) {
|
||||||
tier = 0; // T0: baseline ARMv8-a with ASIMD
|
tier = 1; // T1: baseline ARMv8-a with ASIMD
|
||||||
LOGI("Detected basic ASIMD support - selecting T0");
|
LOGI("Detected basic ASIMD support - selecting T1");
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// Fallback - this shouldn't happen on arm64-v8a devices
|
// Fallback - this shouldn't happen on arm64-v8a devices
|
||||||
tier = 0;
|
tier = 1;
|
||||||
LOGI("No expected features detected - falling back to T0");
|
LOGI("No expected features detected - falling back to T1");
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
LOGI("non aarch64 architecture detected - defaulting to T0");
|
||||||
|
#endif
|
||||||
|
|
||||||
return tier;
|
return tier;
|
||||||
}
|
}
|
||||||
|
|
@ -48,15 +56,19 @@ Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
||||||
extern "C" JNIEXPORT jstring JNICALL
|
extern "C" JNIEXPORT jstring JNICALL
|
||||||
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
|
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
|
||||||
JNIEnv* env,
|
JNIEnv* env,
|
||||||
jclass clazz) {
|
jobject /*clazz*/) {
|
||||||
std::string text;
|
std::string text;
|
||||||
|
|
||||||
|
#if defined(__aarch64__)
|
||||||
if (features.asimd) text += "ASIMD ";
|
if (features.asimd) text += "ASIMD ";
|
||||||
if (features.asimddp) text += "ASIMDDP ";
|
if (features.asimddp) text += "ASIMDDP ";
|
||||||
if (features.i8mm) text += "I8MM ";
|
if (features.i8mm) text += "I8MM ";
|
||||||
if (features.sve) text += "SVE ";
|
if (features.sve) text += "SVE ";
|
||||||
if (features.sve2) text += "SVE2 ";
|
if (features.sve2) text += "SVE2 ";
|
||||||
if (features.sme) text += "SME ";
|
if (features.sme) text += "SME ";
|
||||||
|
#else
|
||||||
|
LOGI("non aarch64 architecture detected");
|
||||||
|
#endif
|
||||||
|
|
||||||
return env->NewStringUTF(text.c_str());
|
return env->NewStringUTF(text.c_str());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -81,10 +81,10 @@ object ArmFeaturesMapper {
|
||||||
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
|
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
|
||||||
when (tier) {
|
when (tier) {
|
||||||
LLamaTier.NONE, null -> null // No tier detected
|
LLamaTier.NONE, null -> null // No tier detected
|
||||||
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
|
LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
|
||||||
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||||
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||||
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||||
LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,17 +12,17 @@ interface TierDetection {
|
||||||
* ARM optimization tiers supported by the Kleidi-Llama library.
|
* ARM optimization tiers supported by the Kleidi-Llama library.
|
||||||
* Higher tiers provide better performance on supported hardware.
|
* Higher tiers provide better performance on supported hardware.
|
||||||
*/
|
*/
|
||||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
enum class LLamaTier(val rawValue: Int, val description: String) {
|
||||||
NONE(-404, "", "No valid Arm® optimization available!"),
|
NONE(0, "No valid Arm® optimization available!"),
|
||||||
T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"),
|
T1(1, "ARMv8-a baseline with ASIMD"),
|
||||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
T2(2, "ARMv8.2-a with DotProd"),
|
||||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
T3(3, "ARMv8.6-a with DotProd + I8MM"),
|
||||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
|
T4(4, "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
|
||||||
T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
||||||
|
|
||||||
val maxSupportedTier = T3
|
val maxSupportedTier = T5
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue