core: restructure Kleidi-Llama library

This commit is contained in:
Han Yin 2025-09-03 14:19:23 -07:00
parent 6cde2fe1bd
commit 6db4c70991
10 changed files with 172 additions and 182 deletions

View File

@@ -10,8 +10,7 @@ import android.util.Log
object StubTierDetection : TierDetection {
private val tag = StubTierDetection::class.java.simpleName
override val detectedTier: LLamaTier?
get() = LLamaTier.T2
override fun getDetectedTier(): LLamaTier? = LLamaTier.T2
override fun clearCache() {
Log.d(tag, "Cache cleared")

View File

@@ -67,7 +67,7 @@ class SettingsViewModel @Inject constructor(
val darkThemeMode: StateFlow<DarkThemeMode> = _darkThemeMode.asStateFlow()
val detectedTier: LLamaTier?
get() = tierDetection.detectedTier
get() = tierDetection.getDetectedTier()
init {
viewModelScope.launch {

View File

@@ -17,8 +17,8 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
add_subdirectory(
${CMAKE_CURRENT_LIST_DIR}/../../../../../../include/cpu_features
${CMAKE_BINARY_DIR}/cpu_features_build)
add_library(llama_cpu_detector SHARED cpu_detector.cpp)
target_link_libraries(llama_cpu_detector
add_library(kleidi-llama-cpu-detector SHARED cpu_detector.cpp)
target_link_libraries(kleidi-llama-cpu-detector
PRIVATE CpuFeatures::cpu_features
android
log)

View File

@@ -12,7 +12,7 @@ static const Aarch64Info info = GetAarch64Info();
static const Aarch64Features features = info.features;
extern "C" JNIEXPORT jint JNICALL
Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier(
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
JNIEnv* env,
jclass clazz) {
int tier = 0; // Default to T0 (baseline)
@@ -46,7 +46,7 @@ Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier(
// Optional: Keep a feature string function for debugging
extern "C" JNIEXPORT jstring JNICALL
Java_android_llama_cpp_internal_InferenceEngineLoader_getCpuFeaturesString(
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
JNIEnv* env,
jclass clazz) {
std::string text;

View File

@@ -84,7 +84,7 @@ object ArmFeaturesMapper {
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
}
}

View File

@@ -2,6 +2,7 @@ package android.llama.cpp
import android.content.Context
import android.llama.cpp.internal.InferenceEngineFactory
import android.llama.cpp.internal.TierDetectionImpl
/**
* Main entry point for the Llama Android library.
@@ -11,12 +12,10 @@ object KleidiLlama {
/**
* Create an inference engine instance with automatic tier detection.
*/
fun createInferenceEngine(context: Context) =
InferenceEngineFactory.getInstance(context)
fun createInferenceEngine(context: Context) = InferenceEngineFactory.getInstance(context)
/**
* Get tier detection information for debugging/settings.
*/
fun getTierDetection(context: Context) =
InferenceEngineFactory.getTierDetection(context)
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
}

View File

@@ -4,7 +4,7 @@ package android.llama.cpp
* Public interface for [LLamaTier] detection information.
*/
interface TierDetection {
val detectedTier: LLamaTier?
fun getDetectedTier(): LLamaTier?
fun clearCache()
}
@@ -17,8 +17,8 @@ enum class LLamaTier(val rawValue: Int, val libraryName: String, val description
T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"),
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2");
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
companion object {
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }

View File

@@ -1,13 +1,45 @@
package android.llama.cpp.internal
import android.content.Context
import android.llama.cpp.InferenceEngine
import android.llama.cpp.TierDetection
import android.util.Log
import kotlinx.coroutines.runBlocking
/**
* Internal factory to create [InferenceEngine] and [TierDetection]
*/
internal object InferenceEngineFactory {
// NOTE(review): this diff rendering interleaves pre- and post-change lines without
// +/- markers; the one-line getInstance delegating to InferenceEngineLoader below
// duplicates the full @Synchronized getInstance further down — confirm against the
// actual post-change file, which presumably keeps only the full implementation.
fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)
// Log tag for this factory.
private val TAG = InferenceEngineFactory::class.simpleName
// Builds a TierDetection facade; constructing it here does not trigger detection.
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
// Process-wide cached engine instance; writes happen under @Synchronized getInstance.
private var _cachedInstance: InferenceEngineImpl? = null
/**
 * Factory method to get a configured [InferenceEngineImpl] instance.
 * Handles tier detection, caching, and library loading automatically.
 *
 * Returns the cached engine when available; otherwise creates one and caches it.
 * On failure the exception is logged and null is returned — callers must handle null.
 */
@Synchronized
fun getInstance(context: Context): InferenceEngine? {
// Return cached instance if available
_cachedInstance?.let { return it }
// NOTE(review): runBlocking blocks the calling thread for the whole engine
// creation — confirm this is never invoked on the main thread.
return runBlocking {
try {
// Create and cache the inference engine instance
InferenceEngineImpl.create(context).also {
_cachedInstance = it
Log.i(TAG, "Successfully instantiated Inference Engine")
}
} catch (e: Exception) {
// Deliberate best-effort: log and return null rather than propagate.
Log.e(TAG, "Error instantiating Inference Engine", e)
null
}
}
}
// Drops the cached engine so the next getInstance() call builds a fresh one.
fun clearCache() {
_cachedInstance = null
Log.i(TAG, "Cleared cached instance of InferenceEngine")
}
}

View File

@@ -1,159 +0,0 @@
package android.llama.cpp.internal
import android.content.Context
import android.llama.cpp.InferenceEngine
import android.llama.cpp.LLamaTier
import android.util.Log
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
/**
* Internal [android.llama.cpp.InferenceEngine] loader implementation
*/
internal object InferenceEngineLoader {
// Log tag; simpleName on a named object resolves to "InferenceEngineLoader".
private val TAG = InferenceEngineLoader::class.simpleName
// CPU feature detection preferences
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
// Preferences DataStore persisting the detected tier across process restarts.
private val Context.llamaTierDataStore: DataStore<Preferences>
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
// Version stamp stored alongside the tier so a detection-logic change can
// invalidate previously cached results (see loadDetectedTierFromDataStore).
private val DETECTION_VERSION = intPreferencesKey("detection_version")
// Raw value of the detected LLamaTier; >= 0 when a valid entry exists.
private val DETECTED_TIER = intPreferencesKey("detected_tier")
// Constants
private const val DATASTORE_VERSION = 1
// Native hooks from the "llama_cpu_detector" library; the library is loaded
// lazily in performOptimalTierDetection(), so calling these before that load
// would throw UnsatisfiedLinkError. @JvmStatic keeps the JNI symbols bound to
// this object's class with a jclass receiver, matching the C++ declarations.
@JvmStatic
private external fun getOptimalTier(): Int
@JvmStatic
private external fun getCpuFeaturesString(): String
// In-memory caches: the shared engine instance and the last detection result.
private var _cachedInstance: InferenceEngineImpl? = null
private var _detectedTier: LLamaTier? = null
/**
 * Get the detected tier, loading from cache if needed.
 *
 * NOTE(review): first call falls through to runBlocking and may perform
 * DataStore I/O plus native detection synchronously on the calling thread.
 */
fun getDetectedTier(context: Context): LLamaTier? =
_detectedTier ?: runBlocking { obtainTier(context) }
/**
 * Factory method to get a configured [InferenceEngineImpl] instance.
 * Handles tier detection, caching, and library loading automatically.
 *
 * Returns null when creation fails; the failure is logged, not rethrown.
 */
@Synchronized
fun getInstance(context: Context): InferenceEngine? {
// Return cached instance if available
_cachedInstance?.let { return it }
return runBlocking {
try {
// Create and cache the inference engine instance
InferenceEngineImpl.create(context).also {
_cachedInstance = it
Log.i(TAG, "Successfully instantiated Inference Engine")
}
} catch (e: Exception) {
// Deliberate best-effort: callers degrade gracefully on a null engine.
Log.e(TAG, "Error instantiating Inference Engine", e)
null
}
}
}
/**
 * First attempt to load detected tier from storage, if available;
 * Otherwise, perform a fresh detection, then save to storage and cache.
 */
private suspend fun obtainTier(context: Context) =
loadDetectedTierFromDataStore(context) ?: run {
Log.i(TAG, "Performing fresh tier detection")
performOptimalTierDetection().also { tier ->
// Persist first, then memoize; a null tier (load failure) is not saved.
tier?.saveToDataStore(context)
_detectedTier = tier
}
}
/**
 * Load cached tier from datastore without performing detection.
 * Accepts the entry only when the stored version matches DATASTORE_VERSION
 * and the stored raw value is non-negative; otherwise returns null.
 */
private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
val preferences = context.llamaTierDataStore.data.first()
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
LLamaTier.fromRawValue(cachedTierValue)?.also {
Log.i(TAG, "Loaded cached tier: ${it.name}")
// Memoize so later getDetectedTier() calls skip DataStore entirely.
_detectedTier = it
}
} else {
Log.i(TAG, "No valid cached tier found")
null
}
}
/**
 * Actual implementation of optimal tier detection via native methods.
 * Returns null when the native library cannot be loaded or detection fails.
 *
 * NOTE(review): LLamaTier.NONE and LLamaTier.maxSupportedTier are referenced
 * here but are not among the enum members visible in this chunk (T0..T4) —
 * confirm both exist in the full LLamaTier definition.
 */
private fun performOptimalTierDetection(): LLamaTier? {
try {
// Load CPU detection library
System.loadLibrary("llama_cpu_detector")
Log.i(TAG, "CPU feature detector loaded successfully")
// Detect optimal tier
val tierValue = getOptimalTier()
val features = getCpuFeaturesString()
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
// Convert to enum and validate
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
Log.e(TAG, "Invalid tier value $tierValue")
return LLamaTier.NONE
}
// Ensure we don't exceed maximum supported tier
val maxTier = LLamaTier.maxSupportedTier
return if (tier.rawValue > maxTier.rawValue) {
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
maxTier
} else {
tier
}
} catch (e: UnsatisfiedLinkError) {
// Native library missing/unlinkable on this device — report "unknown tier".
Log.e(TAG, "Failed to load CPU detection library", e)
return null
} catch (e: Exception) {
Log.e(TAG, "Unexpected error during tier detection", e)
return null
}
}
/**
 * Clear cached detection results (for testing/debugging).
 * Wipes the DataStore entry and both in-memory caches.
 */
fun clearCache(context: Context) {
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
_cachedInstance = null
_detectedTier = null
Log.i(TAG, "Cleared detection results and cached instance")
}
// Persists this tier plus the current schema version to the DataStore.
private suspend fun LLamaTier.saveToDataStore(context: Context) {
context.llamaTierDataStore.edit { prefs ->
prefs[DETECTED_TIER] = this.rawValue
prefs[DETECTION_VERSION] = DATASTORE_VERSION
}
Log.i(TAG, "Saved ${this.name} to data store")
}
}

View File

@@ -3,13 +3,132 @@ package android.llama.cpp.internal
import android.content.Context
import android.llama.cpp.LLamaTier
import android.llama.cpp.TierDetection
import android.util.Log
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
/**
* Internal tier detection implementation
* Internal [LLamaTier] detection implementation
*/
internal class TierDetectionImpl(private val context: Context) : TierDetection {
override val detectedTier: LLamaTier?
get() = InferenceEngineLoader.getDetectedTier(context)
internal class TierDetectionImpl(
private val context: Context
): TierDetection {
// NOTE(review): the next line is the pre-change clearCache (this diff rendering
// interleaves removed and added lines); the replacement clearCache appears below —
// confirm the actual file keeps only the full implementation.
override fun clearCache() = InferenceEngineLoader.clearCache(context)
companion object {
// Log tag; simpleName resolves to "TierDetectionImpl".
private val TAG = TierDetectionImpl::class.simpleName
// CPU feature detection preferences
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
// Bump to invalidate previously persisted detection results.
private const val DATASTORE_VERSION = 1
// Preferences DataStore persisting the detected tier across restarts.
private val Context.llamaTierDataStore: DataStore<Preferences>
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
private val DETECTION_VERSION = intPreferencesKey("detection_version")
private val DETECTED_TIER = intPreferencesKey("detected_tier")
}
// Native hooks from the "kleidi-llama-cpu-detector" library, loaded lazily in
// performOptimalTierDetection().
// NOTE(review): the C++ side declares these with a jclass parameter (static
// receiver), but here they are instance methods (jobject receiver) and not
// @JvmStatic — confirm the native signatures were updated to match.
private external fun getOptimalTier(): Int
private external fun getCpuFeaturesString(): String
// Memoized detection result for this instance.
private var _detectedTier: LLamaTier? = null
/**
 * Get the detected tier, loading from cache if needed.
 *
 * NOTE(review): first call uses runBlocking and may perform DataStore I/O plus
 * native detection synchronously on the calling thread.
 */
override fun getDetectedTier(): LLamaTier? =
_detectedTier ?: runBlocking { obtainTier() }
/**
 * First attempt to load detected tier from storage, if available;
 * Otherwise, perform a fresh detection, then save to storage and cache.
 */
private suspend fun obtainTier() =
loadDetectedTierFromDataStore() ?: run {
Log.i(TAG, "Performing fresh tier detection")
performOptimalTierDetection().also { tier ->
// Persist first, then memoize; a null tier (load failure) is not saved.
tier?.saveToDataStore()
_detectedTier = tier
}
}
/**
 * Load cached tier from datastore without performing detection.
 * Valid only when the stored version matches DATASTORE_VERSION and the stored
 * raw value is non-negative; otherwise returns null.
 */
private suspend fun loadDetectedTierFromDataStore(): LLamaTier? {
val preferences = context.llamaTierDataStore.data.first()
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
LLamaTier.fromRawValue(cachedTierValue)?.also {
Log.i(TAG, "Loaded cached tier: ${it.name}")
// Memoize so later getDetectedTier() calls skip DataStore entirely.
_detectedTier = it
}
} else {
Log.i(TAG, "No valid cached tier found")
null
}
}
/**
 * Actual implementation of optimal tier detection via native methods.
 * Returns null when the native library cannot be loaded or detection fails.
 *
 * NOTE(review): LLamaTier.NONE and LLamaTier.maxSupportedTier are referenced
 * below but are not among the enum members visible in this chunk (T0..T4) —
 * confirm both exist in the full LLamaTier definition.
 */
private fun performOptimalTierDetection(): LLamaTier? {
try {
// Load CPU detection library
System.loadLibrary("kleidi-llama-cpu-detector")
Log.i(TAG, "CPU feature detector loaded successfully")
// Detect optimal tier
val tierValue = getOptimalTier()
val features = getCpuFeaturesString()
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
// Convert to enum and validate
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
Log.e(TAG, "Invalid tier value $tierValue")
return LLamaTier.NONE
}
// Ensure we don't exceed maximum supported tier
val maxTier = LLamaTier.maxSupportedTier
return if (tier.rawValue > maxTier.rawValue) {
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
maxTier
} else {
tier
}
} catch (e: UnsatisfiedLinkError) {
// Native library missing/unlinkable on this device — report "unknown tier".
Log.e(TAG, "Failed to load CPU detection library", e)
return null
} catch (e: Exception) {
Log.e(TAG, "Unexpected error during tier detection", e)
return null
}
}
/**
 * Clear cached detection results (for testing/debugging).
 * Wipes the DataStore entry and the in-memory memoized tier.
 */
override fun clearCache() {
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
_detectedTier = null
Log.i(TAG, "Cleared CPU detection results")
}
// Persists this tier plus the current schema version to the DataStore.
private suspend fun LLamaTier.saveToDataStore() {
context.llamaTierDataStore.edit { prefs ->
prefs[DETECTED_TIER] = this.rawValue
prefs[DETECTION_VERSION] = DATASTORE_VERSION
}
Log.i(TAG, "Saved ${this.name} to data store")
}
}