core: restructure Kleidi-Llama library
This commit is contained in:
parent
6cde2fe1bd
commit
6db4c70991
|
|
@ -10,8 +10,7 @@ import android.util.Log
|
|||
object StubTierDetection : TierDetection {
|
||||
private val tag = StubTierDetection::class.java.simpleName
|
||||
|
||||
override val detectedTier: LLamaTier?
|
||||
get() = LLamaTier.T2
|
||||
override fun getDetectedTier(): LLamaTier? = LLamaTier.T2
|
||||
|
||||
override fun clearCache() {
|
||||
Log.d(tag, "Cache cleared")
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ class SettingsViewModel @Inject constructor(
|
|||
val darkThemeMode: StateFlow<DarkThemeMode> = _darkThemeMode.asStateFlow()
|
||||
|
||||
val detectedTier: LLamaTier?
|
||||
get() = tierDetection.detectedTier
|
||||
get() = tierDetection.getDetectedTier()
|
||||
|
||||
init {
|
||||
viewModelScope.launch {
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
|
|||
add_subdirectory(
|
||||
${CMAKE_CURRENT_LIST_DIR}/../../../../../../include/cpu_features
|
||||
${CMAKE_BINARY_DIR}/cpu_features_build)
|
||||
add_library(llama_cpu_detector SHARED cpu_detector.cpp)
|
||||
target_link_libraries(llama_cpu_detector
|
||||
add_library(kleidi-llama-cpu-detector SHARED cpu_detector.cpp)
|
||||
target_link_libraries(kleidi-llama-cpu-detector
|
||||
PRIVATE CpuFeatures::cpu_features
|
||||
android
|
||||
log)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ static const Aarch64Info info = GetAarch64Info();
|
|||
static const Aarch64Features features = info.features;
|
||||
|
||||
extern "C" JNIEXPORT jint JNICALL
|
||||
Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier(
|
||||
Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
int tier = 0; // Default to T0 (baseline)
|
||||
|
|
@ -46,7 +46,7 @@ Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier(
|
|||
|
||||
// Optional: Keep a feature string function for debugging
|
||||
extern "C" JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_internal_InferenceEngineLoader_getCpuFeaturesString(
|
||||
Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString(
|
||||
JNIEnv* env,
|
||||
jclass clazz) {
|
||||
std::string text;
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ object ArmFeaturesMapper {
|
|||
LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only
|
||||
LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD
|
||||
LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM
|
||||
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2
|
||||
LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package android.llama.cpp
|
|||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.internal.InferenceEngineFactory
|
||||
import android.llama.cpp.internal.TierDetectionImpl
|
||||
|
||||
/**
|
||||
* Main entry point for the Llama Android library.
|
||||
|
|
@ -11,12 +12,10 @@ object KleidiLlama {
|
|||
/**
|
||||
* Create an inference engine instance with automatic tier detection.
|
||||
*/
|
||||
fun createInferenceEngine(context: Context) =
|
||||
InferenceEngineFactory.getInstance(context)
|
||||
fun createInferenceEngine(context: Context) = InferenceEngineFactory.getInstance(context)
|
||||
|
||||
/**
|
||||
* Get tier detection information for debugging/settings.
|
||||
*/
|
||||
fun getTierDetection(context: Context) =
|
||||
InferenceEngineFactory.getTierDetection(context)
|
||||
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ package android.llama.cpp
|
|||
* Public interface for [LLamaTier] detection information.
|
||||
*/
|
||||
interface TierDetection {
|
||||
val detectedTier: LLamaTier?
|
||||
fun getDetectedTier(): LLamaTier?
|
||||
fun clearCache()
|
||||
}
|
||||
|
||||
|
|
@ -17,8 +17,8 @@ enum class LLamaTier(val rawValue: Int, val libraryName: String, val description
|
|||
T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"),
|
||||
T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"),
|
||||
T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"),
|
||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2");
|
||||
// TODO-han.yin: implement T4 once obtaining an Android device with SME!
|
||||
T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"),
|
||||
T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2");
|
||||
|
||||
companion object {
|
||||
fun fromRawValue(value: Int): LLamaTier? = entries.find { it.rawValue == value }
|
||||
|
|
|
|||
|
|
@ -1,13 +1,45 @@
|
|||
package android.llama.cpp.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.TierDetection
|
||||
import android.util.Log
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
/**
|
||||
* Internal factory to create [InferenceEngine] and [TierDetection]
|
||||
*/
|
||||
internal object InferenceEngineFactory {
|
||||
fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context)
|
||||
private val TAG = InferenceEngineFactory::class.simpleName
|
||||
|
||||
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context)
|
||||
private var _cachedInstance: InferenceEngineImpl? = null
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
* Handles tier detection, caching, and library loading automatically.
|
||||
*/
|
||||
@Synchronized
|
||||
fun getInstance(context: Context): InferenceEngine? {
|
||||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
return runBlocking {
|
||||
try {
|
||||
// Create and cache the inference engine instance
|
||||
InferenceEngineImpl.create(context).also {
|
||||
_cachedInstance = it
|
||||
Log.i(TAG, "Successfully instantiated Inference Engine")
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error instantiating Inference Engine", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun clearCache() {
|
||||
_cachedInstance = null
|
||||
Log.i(TAG, "Cleared cached instance of InferenceEngine")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,159 +0,0 @@
|
|||
package android.llama.cpp.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.llama.cpp.InferenceEngine
|
||||
import android.llama.cpp.LLamaTier
|
||||
import android.util.Log
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
/**
|
||||
* Internal [android.llama.cpp.InferenceEngine] loader implementation
|
||||
*/
|
||||
internal object InferenceEngineLoader {
|
||||
private val TAG = InferenceEngineLoader::class.simpleName
|
||||
|
||||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
private val DETECTED_TIER = intPreferencesKey("detected_tier")
|
||||
|
||||
// Constants
|
||||
private const val DATASTORE_VERSION = 1
|
||||
|
||||
@JvmStatic
|
||||
private external fun getOptimalTier(): Int
|
||||
|
||||
@JvmStatic
|
||||
private external fun getCpuFeaturesString(): String
|
||||
|
||||
private var _cachedInstance: InferenceEngineImpl? = null
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
|
||||
/**
|
||||
* Get the detected tier, loading from cache if needed
|
||||
*/
|
||||
fun getDetectedTier(context: Context): LLamaTier? =
|
||||
_detectedTier ?: runBlocking { obtainTier(context) }
|
||||
|
||||
/**
|
||||
* Factory method to get a configured [InferenceEngineImpl] instance.
|
||||
* Handles tier detection, caching, and library loading automatically.
|
||||
*/
|
||||
@Synchronized
|
||||
fun getInstance(context: Context): InferenceEngine? {
|
||||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
return runBlocking {
|
||||
try {
|
||||
// Create and cache the inference engine instance
|
||||
InferenceEngineImpl.create(context).also {
|
||||
_cachedInstance = it
|
||||
Log.i(TAG, "Successfully instantiated Inference Engine")
|
||||
}
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error instantiating Inference Engine", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* First attempt to load detected tier from storage, if available;
|
||||
* Otherwise, perform a fresh detection, then save to storage and cache.
|
||||
*/
|
||||
private suspend fun obtainTier(context: Context) =
|
||||
loadDetectedTierFromDataStore(context) ?: run {
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
performOptimalTierDetection().also { tier ->
|
||||
tier?.saveToDataStore(context)
|
||||
_detectedTier = tier
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load cached tier from datastore without performing detection
|
||||
*/
|
||||
private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
|
||||
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
LLamaTier.fromRawValue(cachedTierValue)?.also {
|
||||
Log.i(TAG, "Loaded cached tier: ${it.name}")
|
||||
_detectedTier = it
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "No valid cached tier found")
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Actual implementation of optimal tier detection via native methods
|
||||
*/
|
||||
private fun performOptimalTierDetection(): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
Log.i(TAG, "CPU feature detector loaded successfully")
|
||||
|
||||
// Detect optimal tier
|
||||
val tierValue = getOptimalTier()
|
||||
val features = getCpuFeaturesString()
|
||||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
|
||||
Log.e(TAG, "Invalid tier value $tierValue")
|
||||
return LLamaTier.NONE
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val maxTier = LLamaTier.maxSupportedTier
|
||||
return if (tier.rawValue > maxTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
|
||||
maxTier
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load CPU detection library", e)
|
||||
return null
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Unexpected error during tier detection", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
}
|
||||
|
||||
private suspend fun LLamaTier.saveToDataStore(context: Context) {
|
||||
context.llamaTierDataStore.edit { prefs ->
|
||||
prefs[DETECTED_TIER] = this.rawValue
|
||||
prefs[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
Log.i(TAG, "Saved ${this.name} to data store")
|
||||
}
|
||||
}
|
||||
|
|
@ -3,13 +3,132 @@ package android.llama.cpp.internal
|
|||
import android.content.Context
|
||||
import android.llama.cpp.LLamaTier
|
||||
import android.llama.cpp.TierDetection
|
||||
import android.util.Log
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
/**
|
||||
* Internal tier detection implementation
|
||||
* Internal [LLamaTier] detection implementation
|
||||
*/
|
||||
internal class TierDetectionImpl(private val context: Context) : TierDetection {
|
||||
override val detectedTier: LLamaTier?
|
||||
get() = InferenceEngineLoader.getDetectedTier(context)
|
||||
internal class TierDetectionImpl(
|
||||
private val context: Context
|
||||
): TierDetection {
|
||||
|
||||
override fun clearCache() = InferenceEngineLoader.clearCache(context)
|
||||
companion object {
|
||||
private val TAG = TierDetectionImpl::class.simpleName
|
||||
|
||||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
|
||||
private const val DATASTORE_VERSION = 1
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
private val DETECTED_TIER = intPreferencesKey("detected_tier")
|
||||
}
|
||||
|
||||
private external fun getOptimalTier(): Int
|
||||
|
||||
private external fun getCpuFeaturesString(): String
|
||||
|
||||
private var _detectedTier: LLamaTier? = null
|
||||
|
||||
/**
|
||||
* Get the detected tier, loading from cache if needed
|
||||
*/
|
||||
override fun getDetectedTier(): LLamaTier? =
|
||||
_detectedTier ?: runBlocking { obtainTier() }
|
||||
|
||||
/**
|
||||
* First attempt to load detected tier from storage, if available;
|
||||
* Otherwise, perform a fresh detection, then save to storage and cache.
|
||||
*/
|
||||
private suspend fun obtainTier() =
|
||||
loadDetectedTierFromDataStore() ?: run {
|
||||
Log.i(TAG, "Performing fresh tier detection")
|
||||
performOptimalTierDetection().also { tier ->
|
||||
tier?.saveToDataStore()
|
||||
_detectedTier = tier
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load cached tier from datastore without performing detection
|
||||
*/
|
||||
private suspend fun loadDetectedTierFromDataStore(): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
|
||||
return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
LLamaTier.fromRawValue(cachedTierValue)?.also {
|
||||
Log.i(TAG, "Loaded cached tier: ${it.name}")
|
||||
_detectedTier = it
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "No valid cached tier found")
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Actual implementation of optimal tier detection via native methods
|
||||
*/
|
||||
private fun performOptimalTierDetection(): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("kleidi-llama-cpu-detector")
|
||||
Log.i(TAG, "CPU feature detector loaded successfully")
|
||||
|
||||
// Detect optimal tier
|
||||
val tierValue = getOptimalTier()
|
||||
val features = getCpuFeaturesString()
|
||||
Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features")
|
||||
|
||||
// Convert to enum and validate
|
||||
val tier = LLamaTier.fromRawValue(tierValue) ?: run {
|
||||
Log.e(TAG, "Invalid tier value $tierValue")
|
||||
return LLamaTier.NONE
|
||||
}
|
||||
|
||||
// Ensure we don't exceed maximum supported tier
|
||||
val maxTier = LLamaTier.maxSupportedTier
|
||||
return if (tier.rawValue > maxTier.rawValue) {
|
||||
Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}")
|
||||
maxTier
|
||||
} else {
|
||||
tier
|
||||
}
|
||||
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load CPU detection library", e)
|
||||
return null
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Unexpected error during tier detection", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
override fun clearCache() {
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared CPU detection results")
|
||||
}
|
||||
|
||||
private suspend fun LLamaTier.saveToDataStore() {
|
||||
context.llamaTierDataStore.edit { prefs ->
|
||||
prefs[DETECTED_TIER] = this.rawValue
|
||||
prefs[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
Log.i(TAG, "Saved ${this.name} to data store")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue