lib: support x86-64 by dynamically setting Arm-related definitions
This commit is contained in:
parent
8f90e42ee2
commit
63e5bd0771
|
|
@ -10,7 +10,7 @@ import android.util.Log
|
|||
object StubTierDetection : TierDetection {
|
||||
private val tag = StubTierDetection::class.java.simpleName
|
||||
|
||||
override fun getDetectedTier(): LLamaTier? = LLamaTier.T2
|
||||
override fun getDetectedTier(): LLamaTier? = LLamaTier.T3
|
||||
|
||||
override fun clearCache() {
|
||||
Log.d(tag, "Cache cleared")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ android {
|
|||
consumerProguardFiles("consumer-rules.pro")
|
||||
|
||||
ndk {
|
||||
abiFilters += listOf("arm64-v8a")
|
||||
abiFilters += listOf("arm64-v8a", "x86_64")
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
|
|
@ -29,12 +29,9 @@ android {
|
|||
arguments += "-DLLAMA_BUILD_COMMON=ON"
|
||||
arguments += "-DLLAMA_CURL=OFF"
|
||||
|
||||
arguments += "-DGGML_SYSTEM_ARCH=ARM" // Undocumented before 3.21
|
||||
arguments += "-DGGML_NATIVE=OFF"
|
||||
arguments += "-DGGML_BACKEND_DL=ON"
|
||||
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
|
||||
arguments += "-DGGML_CPU_KLEIDIAI=ON"
|
||||
arguments += "-DGGML_OPENMP=ON"
|
||||
arguments += "-DGGML_LLAMAFILE=OFF"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,12 +27,33 @@ target_link_libraries(kleidi-llama-cpu-detector
|
|||
# 2. Kleidi Llama library
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
if(DEFINED ANDROID_ABI)
|
||||
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
|
||||
if(ANDROID_ABI STREQUAL "arm64-v8a")
|
||||
set(GGML_SYSTEM_ARCH "ARM")
|
||||
set(GGML_CPU_KLEIDIAI ON)
|
||||
set(GGML_OPENMP ON)
|
||||
elseif(ANDROID_ABI STREQUAL "x86_64")
|
||||
set(GGML_SYSTEM_ARCH "x86")
|
||||
set(GGML_CPU_KLEIDIAI OFF)
|
||||
set(GGML_OPENMP OFF)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
|
||||
add_subdirectory(${LLAMA_SRC} build-llama)
|
||||
|
||||
add_library(${CMAKE_PROJECT_NAME} SHARED
|
||||
kleidi-llama.cpp)
|
||||
|
||||
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
|
||||
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
|
||||
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
|
||||
)
|
||||
|
||||
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
${LLAMA_SRC}
|
||||
${LLAMA_SRC}/common
|
||||
|
|
|
|||
|
|
@ -1,45 +1,53 @@
|
|||
#include <jni.h>
|
||||
#include <cpuinfo_aarch64.h>
|
||||
#include <android/log.h>
|
||||
#include <string>
|
||||
|
||||
#if defined(__aarch64__)
|
||||
#include <cpuinfo_aarch64.h>
|
||||
using namespace cpu_features;
|
||||
static const Aarch64Info info = GetAarch64Info();
|
||||
static const Aarch64Features features = info.features;
|
||||
#endif
|
||||
|
||||
#define LOG_TAG "CpuDetector"
|
||||
#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, __VA_ARGS__)
|
||||
|
||||
static const Aarch64Info info = GetAarch64Info();
|
||||
static const Aarch64Features features = info.features;
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
JNIEnv* /*env*/,
|
||||
jobject /*clazz*/) {
|
||||
int tier = 0; // Default to T0 (baseline)
|
||||
|
||||
#if defined(__aarch64__)
|
||||
// Check features in reverse order (highest tier first)
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
if (features.sve && features.sve2) {
|
||||
tier = 3; // T3: ARMv9-a with SVE/SVE2
|
||||
LOGI("Detected SVE/SVE2 support - selecting T3");
|
||||
if (features.sme) {
|
||||
tier = 5; // T5: ARMv9.2-a with SME
|
||||
LOGI("Detected SME support - selecting T5");
|
||||
}
|
||||
else if (features.sve && features.sve2) {
|
||||
tier = 4; // T4: ARMv9-a with SVE/SVE2
|
||||
LOGI("Detected SVE/SVE2 support - selecting T4");
|
||||
}
|
||||
else if (features.i8mm) {
|
||||
tier = 2; // T2: ARMv8.6-a with i8mm
|
||||
LOGI("Detected i8mm support - selecting T2");
|
||||
tier = 3; // T3: ARMv8.6-a with i8mm
|
||||
LOGI("Detected i8mm support - selecting T3");
|
||||
}
|
||||
else if (features.asimddp) {
|
||||
tier = 1; // T1: ARMv8.2-a with dotprod
|
||||
LOGI("Detected dotprod support - selecting T1");
|
||||
tier = 2; // T2: ARMv8.2-a with dotprod
|
||||
LOGI("Detected dotprod support - selecting T2");
|
||||
}
|
||||
else if (features.asimd) {
|
||||
tier = 0; // T0: baseline ARMv8-a with ASIMD
|
||||
LOGI("Detected basic ASIMD support - selecting T0");
|
||||
tier = 1; // T1: baseline ARMv8-a with ASIMD
|
||||
LOGI("Detected basic ASIMD support - selecting T1");
|
||||
}
|
||||
else {
|
||||
// Fallback - this shouldn't happen on arm64-v8a devices
|
||||
tier = 0;
|
||||
LOGI("No expected features detected - falling back to T0");
|
||||
tier = 1;
|
||||
LOGI("No expected features detected - falling back to T1");
|
||||
}
|
||||
#else
|
||||
LOGI("non aarch64 architecture detected - defaulting to T0");
|
||||
#endif
|
||||
|
||||
return tier;
|
||||
}
|
||||
|
|
@ -48,15 +56,19 @@ Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
|||
extern "C" JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
jobject /*clazz*/) {
|
||||
std::string text;
|
||||
|
||||
#if defined(__aarch64__)
|
||||
if (features.asimd) text += "ASIMD ";
|
||||
if (features.asimddp) text += "ASIMDDP ";
|
||||
if (features.i8mm) text += "I8MM ";
|
||||
if (features.sve) text += "SVE ";
|
||||
if (features.sve2) text += "SVE2 ";
|
||||
if (features.sme) text += "SME ";
|
||||
#else
|
||||
LOGI("non aarch64 architecture detected");
|
||||
#endif
|
||||
|
||||
return env->NewStringUTF(text.c_str());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,10 +81,10 @@ object ArmFeaturesMapper {
|
|||
private fun getSupportedFeatures(tier: LLamaTier?): List<Boolean>? =
|
||||
when (tier) {
|
||||
LLamaTier.NONE, null -> null // No tier detected
|
||||
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||
LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||
LLamaTier.T1 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
LLamaTier.T2 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
LLamaTier.T3 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
LLamaTier.T4 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||
LLamaTier.T5 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,17 +12,17 @@ interface TierDetection {
|
|||
* ARM optimization tiers supported by the Kleidi-Llama library.
|
||||
* Higher tiers provide better performance on supported hardware.
|
||||
*/
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
NONE(-404, "", "No valid Arm® optimization available!"),
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
|
||||
T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
||||
enum class LLamaTier(val rawValue: Int, val description: String) {
|
||||
NONE(0, "No valid Arm® optimization available!"),
|
||||
T1(1, "ARMv8-a baseline with ASIMD"),
|
||||
T2(2, "ARMv8.2-a with DotProd"),
|
||||
T3(3, "ARMv8.6-a with DotProd + I8MM"),
|
||||
T4(4, "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
|
||||
T5(5, "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
||||
|
||||
companion object {
|
||||
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
||||
|
||||
val maxSupportedTier = T3
|
||||
val maxSupportedTier = T5
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue