diff --git a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt index b120ad2c41..f40f8a65af 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/engine/StubTierDetection.kt @@ -10,8 +10,7 @@ import android.util.Log object StubTierDetection : TierDetection { private val tag = StubTierDetection::class.java.simpleName - override val detectedTier: LLamaTier? - get() = LLamaTier.T2 + override fun getDetectedTier(): LLamaTier? = LLamaTier.T2 override fun clearCache() { Log.d(tag, "Cache cleared") diff --git a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/SettingsViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/SettingsViewModel.kt index f715fd25dd..d8e025a05e 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/SettingsViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/viewmodel/SettingsViewModel.kt @@ -67,7 +67,7 @@ class SettingsViewModel @Inject constructor( val darkThemeMode: StateFlow = _darkThemeMode.asStateFlow() val detectedTier: LLamaTier? 
- get() = tierDetection.detectedTier + get() = tierDetection.getDetectedTier() init { viewModelScope.launch { diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt index ae45302cc1..c88112b86d 100644 --- a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt +++ b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt @@ -17,8 +17,8 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE) add_subdirectory( ${CMAKE_CURRENT_LIST_DIR}/../../../../../../include/cpu_features ${CMAKE_BINARY_DIR}/cpu_features_build) -add_library(llama_cpu_detector SHARED cpu_detector.cpp) -target_link_libraries(llama_cpu_detector +add_library(kleidi-llama-cpu-detector SHARED cpu_detector.cpp) +target_link_libraries(kleidi-llama-cpu-detector PRIVATE CpuFeatures::cpu_features android log) diff --git a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp index abad0818aa..9f3a1ce9d6 100644 --- a/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp +++ b/examples/llama.android/llama/src/main/cpp/cpu_detector.cpp @@ -12,7 +12,7 @@ static const Aarch64Info info = GetAarch64Info(); static const Aarch64Features features = info.features; extern "C" JNIEXPORT jint JNICALL -Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier( +Java_android_llama_cpp_internal_TierDetectionImpl_getOptimalTier( JNIEnv* env, jclass clazz) { int tier = 0; // Default to T0 (baseline) @@ -46,7 +46,7 @@ Java_android_llama_cpp_internal_InferenceEngineLoader_getOptimalTier( // Optional: Keep a feature string function for debugging extern "C" JNIEXPORT jstring JNICALL -Java_android_llama_cpp_internal_InferenceEngineLoader_getCpuFeaturesString( +Java_android_llama_cpp_internal_TierDetectionImpl_getCpuFeaturesString( JNIEnv* env, jclass clazz) { std::string text; diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt 
b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt index adb9412ead..1ca4949450 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/ArmFeatures.kt @@ -84,7 +84,7 @@ object ArmFeaturesMapper { LLamaTier.T0 -> listOf(true, false, false, false, false) // ASIMD only LLamaTier.T1 -> listOf(true, true, false, false, false) // ASIMD + DOTPROD LLamaTier.T2 -> listOf(true, true, true, false, false) // ASIMD + DOTPROD + I8MM - LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE - // TODO-han.yin: implement T4 once obtaining an Android device with SME! + LLamaTier.T3 -> listOf(true, true, true, true, false) // ASIMD + DOTPROD + I8MM + SVE/2 + LLamaTier.T4 -> listOf(true, true, true, true, true) // ASIMD + DOTPROD + I8MM + SVE/2 + SME/2 } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt index 2adabe7b91..4643753b35 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt @@ -2,6 +2,7 @@ package android.llama.cpp import android.content.Context import android.llama.cpp.internal.InferenceEngineFactory +import android.llama.cpp.internal.TierDetectionImpl /** * Main entry point for the Llama Android library. @@ -11,12 +12,10 @@ object KleidiLlama { /** * Create an inference engine instance with automatic tier detection. */ - fun createInferenceEngine(context: Context) = - InferenceEngineFactory.getInstance(context) + fun createInferenceEngine(context: Context) = InferenceEngineFactory.getInstance(context) /** * Get tier detection information for debugging/settings. 
*/ - fun getTierDetection(context: Context) = - InferenceEngineFactory.getTierDetection(context) + fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt index a617aa8331..5728c1c45b 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/TierDetection.kt @@ -4,7 +4,7 @@ package android.llama.cpp * Public interface for [LLamaTier] detection information. */ interface TierDetection { - val detectedTier: LLamaTier? + fun getDetectedTier(): LLamaTier? fun clearCache() } @@ -17,8 +17,8 @@ enum class LLamaTier(val rawValue: Int, val libraryName: String, val description T0(0, "llama_android_t0", "ARMv8-a baseline with ASIMD"), T1(1, "llama_android_t1", "ARMv8.2-a with DotProd"), T2(2, "llama_android_t2", "ARMv8.6-a with DotProd + I8MM"), - T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"); - // TODO-han.yin: implement T4 once obtaining an Android device with SME! + T3(3, "llama_android_t3", "ARMv9-a with DotProd + I8MM + SVE/SVE2"), + T4(4, "llama_android_t4", "ARMv9.2-a with DotProd + I8MM + SVE/SVE2 + SME/SME2"); companion object { fun fromRawValue(value: Int): LLamaTier? 
= entries.find { it.rawValue == value } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt index 66943b99b0..811f13b587 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt @@ -1,13 +1,45 @@ package android.llama.cpp.internal import android.content.Context +import android.llama.cpp.InferenceEngine import android.llama.cpp.TierDetection +import android.util.Log +import kotlinx.coroutines.runBlocking /** * Internal factory to create [InferenceEngine] and [TierDetection] */ internal object InferenceEngineFactory { - fun getInstance(context: Context) = InferenceEngineLoader.getInstance(context) + private val TAG = InferenceEngineFactory::class.simpleName - fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) + private var _cachedInstance: InferenceEngineImpl? = null + + /** + * Factory method to get a configured [InferenceEngineImpl] instance. + * Handles tier detection, caching, and library loading automatically. + */ + @Synchronized + fun getInstance(context: Context): InferenceEngine? 
{ + // Return cached instance if available + _cachedInstance?.let { return it } + + return runBlocking { + try { + // Create and cache the inference engine instance + InferenceEngineImpl.create(context).also { + _cachedInstance = it + Log.i(TAG, "Successfully instantiated Inference Engine") + } + + } catch (e: Exception) { + Log.e(TAG, "Error instantiating Inference Engine", e) + null + } + } + } + + fun clearCache() { + _cachedInstance = null + Log.i(TAG, "Cleared cached instance of InferenceEngine") + } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt deleted file mode 100644 index e3c2174e82..0000000000 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineLoader.kt +++ /dev/null @@ -1,159 +0,0 @@ -package android.llama.cpp.internal - -import android.content.Context -import android.llama.cpp.InferenceEngine -import android.llama.cpp.LLamaTier -import android.util.Log -import androidx.datastore.core.DataStore -import androidx.datastore.preferences.core.Preferences -import androidx.datastore.preferences.core.edit -import androidx.datastore.preferences.core.intPreferencesKey -import androidx.datastore.preferences.preferencesDataStore -import kotlinx.coroutines.flow.first -import kotlinx.coroutines.runBlocking - -/** - * Internal [android.llama.cpp.InferenceEngine] loader implementation - */ -internal object InferenceEngineLoader { - private val TAG = InferenceEngineLoader::class.simpleName - - // CPU feature detection preferences - private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection" - private val Context.llamaTierDataStore: DataStore - by preferencesDataStore(name = DATASTORE_CPU_DETECTION) - - private val DETECTION_VERSION = intPreferencesKey("detection_version") - private val DETECTED_TIER = intPreferencesKey("detected_tier") - - // Constants - private const 
val DATASTORE_VERSION = 1 - - @JvmStatic - private external fun getOptimalTier(): Int - - @JvmStatic - private external fun getCpuFeaturesString(): String - - private var _cachedInstance: InferenceEngineImpl? = null - private var _detectedTier: LLamaTier? = null - - /** - * Get the detected tier, loading from cache if needed - */ - fun getDetectedTier(context: Context): LLamaTier? = - _detectedTier ?: runBlocking { obtainTier(context) } - - /** - * Factory method to get a configured [InferenceEngineImpl] instance. - * Handles tier detection, caching, and library loading automatically. - */ - @Synchronized - fun getInstance(context: Context): InferenceEngine? { - // Return cached instance if available - _cachedInstance?.let { return it } - - return runBlocking { - try { - // Create and cache the inference engine instance - InferenceEngineImpl.create(context).also { - _cachedInstance = it - Log.i(TAG, "Successfully instantiated Inference Engine") - } - - } catch (e: Exception) { - Log.e(TAG, "Error instantiating Inference Engine", e) - null - } - } - } - - /** - * First attempt to load detected tier from storage, if available; - * Otherwise, perform a fresh detection, then save to storage and cache. - */ - private suspend fun obtainTier(context: Context) = - loadDetectedTierFromDataStore(context) ?: run { - Log.i(TAG, "Performing fresh tier detection") - performOptimalTierDetection().also { tier -> - tier?.saveToDataStore(context) - _detectedTier = tier - } - } - - /** - * Load cached tier from datastore without performing detection - */ - private suspend fun loadDetectedTierFromDataStore(context: Context): LLamaTier? 
{ - val preferences = context.llamaTierDataStore.data.first() - val cachedVersion = preferences[DETECTION_VERSION] ?: -1 - val cachedTierValue = preferences[DETECTED_TIER] ?: -1 - - return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { - LLamaTier.fromRawValue(cachedTierValue)?.also { - Log.i(TAG, "Loaded cached tier: ${it.name}") - _detectedTier = it - } - } else { - Log.i(TAG, "No valid cached tier found") - null - } - } - - /** - * Actual implementation of optimal tier detection via native methods - */ - private fun performOptimalTierDetection(): LLamaTier? { - try { - // Load CPU detection library - System.loadLibrary("llama_cpu_detector") - Log.i(TAG, "CPU feature detector loaded successfully") - - // Detect optimal tier - val tierValue = getOptimalTier() - val features = getCpuFeaturesString() - Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") - - // Convert to enum and validate - val tier = LLamaTier.fromRawValue(tierValue) ?: run { - Log.e(TAG, "Invalid tier value $tierValue") - return LLamaTier.NONE - } - - // Ensure we don't exceed maximum supported tier - val maxTier = LLamaTier.maxSupportedTier - return if (tier.rawValue > maxTier.rawValue) { - Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}") - maxTier - } else { - tier - } - - } catch (e: UnsatisfiedLinkError) { - Log.e(TAG, "Failed to load CPU detection library", e) - return null - - } catch (e: Exception) { - Log.e(TAG, "Unexpected error during tier detection", e) - return null - } - } - - /** - * Clear cached detection results (for testing/debugging) - */ - fun clearCache(context: Context) { - runBlocking { context.llamaTierDataStore.edit { it.clear() } } - _cachedInstance = null - _detectedTier = null - Log.i(TAG, "Cleared detection results and cached instance") - } - - private suspend fun LLamaTier.saveToDataStore(context: Context) { - context.llamaTierDataStore.edit { prefs -> - prefs[DETECTED_TIER] = this.rawValue - 
prefs[DETECTION_VERSION] = DATASTORE_VERSION - } - Log.i(TAG, "Saved ${this.name} to data store") - } -} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt index c439561aba..2683395a62 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt @@ -3,13 +3,132 @@ package android.llama.cpp.internal import android.content.Context import android.llama.cpp.LLamaTier import android.llama.cpp.TierDetection +import android.util.Log +import androidx.datastore.core.DataStore +import androidx.datastore.preferences.core.Preferences +import androidx.datastore.preferences.core.edit +import androidx.datastore.preferences.core.intPreferencesKey +import androidx.datastore.preferences.preferencesDataStore +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.runBlocking /** - * Internal tier detection implementation + * Internal [LLamaTier] detection implementation */ -internal class TierDetectionImpl(private val context: Context) : TierDetection { - override val detectedTier: LLamaTier? 
- get() = InferenceEngineLoader.getDetectedTier(context) +internal class TierDetectionImpl( + private val context: Context +) : TierDetection { - override fun clearCache() = InferenceEngineLoader.clearCache(context) + companion object { + private val TAG = TierDetectionImpl::class.simpleName + + // CPU feature detection preferences + private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection" + private const val DATASTORE_VERSION = 1 + private val Context.llamaTierDataStore: DataStore<Preferences> + by preferencesDataStore(name = DATASTORE_CPU_DETECTION) + + private val DETECTION_VERSION = intPreferencesKey("detection_version") + private val DETECTED_TIER = intPreferencesKey("detected_tier") + } + + private external fun getOptimalTier(): Int + + private external fun getCpuFeaturesString(): String + + private var _detectedTier: LLamaTier? = null + + /** + * Get the detected tier, loading from cache if needed + */ + override fun getDetectedTier(): LLamaTier? = + _detectedTier ?: runBlocking { obtainTier() } + + /** + * First attempt to load detected tier from storage, if available; + * Otherwise, perform a fresh detection, then save to storage and cache. + */ + private suspend fun obtainTier() = + loadDetectedTierFromDataStore() ?: run { + Log.i(TAG, "Performing fresh tier detection") + performOptimalTierDetection().also { tier -> + tier?.saveToDataStore() + _detectedTier = tier + } + } + + /** + * Load cached tier from datastore without performing detection + */ + private suspend fun loadDetectedTierFromDataStore(): LLamaTier?
{ + val preferences = context.llamaTierDataStore.data.first() + val cachedVersion = preferences[DETECTION_VERSION] ?: -1 + val cachedTierValue = preferences[DETECTED_TIER] ?: -1 + + return if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) { + LLamaTier.fromRawValue(cachedTierValue)?.also { + Log.i(TAG, "Loaded cached tier: ${it.name}") + _detectedTier = it + } + } else { + Log.i(TAG, "No valid cached tier found") + null + } + } + + /** + * Actual implementation of optimal tier detection via native methods + */ + private fun performOptimalTierDetection(): LLamaTier? { + try { + // Load CPU detection library + System.loadLibrary("kleidi-llama-cpu-detector") + Log.i(TAG, "CPU feature detector loaded successfully") + + // Detect optimal tier + val tierValue = getOptimalTier() + val features = getCpuFeaturesString() + Log.i(TAG, "Raw tier $tierValue w/ CPU features: $features") + + // Convert to enum and validate + val tier = LLamaTier.fromRawValue(tierValue) ?: run { + Log.e(TAG, "Invalid tier value $tierValue") + return LLamaTier.NONE + } + + // Ensure we don't exceed maximum supported tier + val maxTier = LLamaTier.maxSupportedTier + return if (tier.rawValue > maxTier.rawValue) { + Log.w(TAG, "Detected tier ${tier.name} exceeds max supported, using ${maxTier.name}") + maxTier + } else { + tier + } + + } catch (e: UnsatisfiedLinkError) { + Log.e(TAG, "Failed to load CPU detection library", e) + return null + + } catch (e: Exception) { + Log.e(TAG, "Unexpected error during tier detection", e) + return null + } + } + + /** + * Clear cached detection results (for testing/debugging) + */ + override fun clearCache() { + runBlocking { context.llamaTierDataStore.edit { it.clear() } } + _detectedTier = null + Log.i(TAG, "Cleared CPU detection results") + } + + private suspend fun LLamaTier.saveToDataStore() { + context.llamaTierDataStore.edit { prefs -> + prefs[DETECTED_TIER] = this.rawValue + prefs[DETECTION_VERSION] = DATASTORE_VERSION + } + Log.i(TAG, "Saved 
${this.name} to data store") + } }