diff --git a/examples/llama.android/llama/build.gradle.kts b/examples/llama.android/llama/build.gradle.kts
index 8cdb978eb4..6770ad859d 100644
--- a/examples/llama.android/llama/build.gradle.kts
+++ b/examples/llama.android/llama/build.gradle.kts
@@ -59,6 +59,7 @@ android {

 dependencies {
     implementation(libs.androidx.core.ktx)
+    implementation(libs.androidx.datastore.preferences)

     testImplementation(libs.junit)
     androidTestImplementation(libs.androidx.junit)
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt
index 854726c3f7..89412ac068 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/InferenceEngineLoader.kt
@@ -1,9 +1,14 @@
 package android.llama.cpp

 import android.content.Context
-import android.content.SharedPreferences
 import android.util.Log
-import androidx.core.content.edit
+import androidx.datastore.core.DataStore
+import androidx.datastore.preferences.core.Preferences
+import androidx.datastore.preferences.core.edit
+import androidx.datastore.preferences.core.intPreferencesKey
+import androidx.datastore.preferences.preferencesDataStore
+import kotlinx.coroutines.flow.first
+import kotlinx.coroutines.runBlocking

 enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
     T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
@@ -26,10 +31,16 @@ class InferenceEngineLoader private constructor() {
     companion object {
         private val TAG = InferenceEngineLoader::class.simpleName

-        private const val DETECTION_VERSION = 1
-        private const val PREFS_NAME = "llama_cpu_detection"
-        private const val KEY_DETECTED_TIER = "detected_tier"
-        private const val KEY_DETECTION_VERSION = "detection_version"
+        // CPU feature detection preferences
+        private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
+        private val Context.llamaTierDataStore: DataStore<Preferences>
+            by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
+
+        private val DETECTION_VERSION = intPreferencesKey("detection_version")
+        private val DETECTED_TIER = intPreferencesKey("detected_tier")
+
+        // Constants
+        private const val DATASTORE_VERSION = 1

         @JvmStatic
         private external fun getOptimalTier(): Int
@@ -50,28 +61,30 @@ class InferenceEngineLoader private constructor() {
             // Return cached instance if available
             _cachedInstance?.let { return it }

-            try {
-                // Obtain the optimal tier from cache if available
-                val tier = getOrDetectOptimalTier(context) ?: run {
-                    Log.e(TAG, "Failed to determine optimal tier")
-                    return null
-                }
-                _detectedTier = tier
-                Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
-
-                // Create and cache the inference engine instance
-                val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
-                    Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
-                    return null
-                }
-                _cachedInstance = instance
-                Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
-
-                return instance
-
-            } catch (e: Exception) {
-                Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
-                return null
-            }
+            return runBlocking {
+                try {
+                    // Obtain the optimal tier from cache if available
+                    val tier = getOrDetectOptimalTier(context) ?: run {
+                        Log.e(TAG, "Failed to determine optimal tier")
+                        return@runBlocking null
+                    }
+                    _detectedTier = tier
+                    Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
+
+                    // Create and cache the inference engine instance
+                    val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
+                        Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
+                        return@runBlocking null
+                    }
+                    _cachedInstance = instance
+                    Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
+
+                    instance
+
+                } catch (e: Exception) {
+                    Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
+                    null
+                }
+            }
         }

@@ -79,7 +92,7 @@ class InferenceEngineLoader private constructor() {
         /**
          * Clear cached detection results (for testing/debugging)
          */
         fun clearCache(context: Context) {
-            getSharedPrefs(context).edit { clear() }
+            runBlocking { context.llamaTierDataStore.edit { it.clear() } }
             _cachedInstance = null
             _detectedTier = null
             Log.i(TAG, "Cleared detection results and cached instance")
@@ -88,13 +101,13 @@ class InferenceEngineLoader private constructor() {
         /**
          * Get optimal tier from cache or detect it fresh
         */
-        private fun getOrDetectOptimalTier(context: Context): LLamaTier? {
-            val prefs = getSharedPrefs(context)
+        private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
+            val preferences = context.llamaTierDataStore.data.first()

             // Check if we have a cached result with the current detection version
-            val cachedVersion = prefs.getInt(KEY_DETECTION_VERSION, -1)
-            val cachedTierValue = prefs.getInt(KEY_DETECTED_TIER, -1)
-            if (cachedVersion == DETECTION_VERSION && cachedTierValue >= 0) {
+            val cachedVersion = preferences[DETECTION_VERSION] ?: -1
+            val cachedTierValue = preferences[DETECTED_TIER] ?: -1
+            if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
                 val cachedTier = LLamaTier.fromRawValue(cachedTierValue)
                 if (cachedTier != null) {
                     Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
@@ -110,7 +123,7 @@ class InferenceEngineLoader private constructor() {
         /**
          * Detect optimal tier and save to cache
         */
-        private fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
+        private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
             try {
                 // Load CPU detection library
                 System.loadLibrary("llama_cpu_detector")
@@ -136,9 +149,9 @@ class InferenceEngineLoader private constructor() {
                 }

                 // Cache the result
-                getSharedPrefs(context).edit {
-                    putInt(KEY_DETECTED_TIER, finalTier.rawValue)
-                    putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
+                context.llamaTierDataStore.edit {
+                    it[DETECTED_TIER] = finalTier.rawValue
+                    it[DETECTION_VERSION] = DATASTORE_VERSION
                 }

                 Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
@@ -149,9 +162,9 @@ class InferenceEngineLoader private constructor() {

                 // Fallback to T0 and cache it
                 val fallbackTier = LLamaTier.T0
-                getSharedPrefs(context).edit {
-                    putInt(KEY_DETECTED_TIER, fallbackTier.rawValue)
-                    putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
+                context.llamaTierDataStore.edit {
+                    it[DETECTED_TIER] = fallbackTier.rawValue
+                    it[DETECTION_VERSION] = DATASTORE_VERSION
                 }

                 Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
@@ -162,9 +175,5 @@ class InferenceEngineLoader private constructor() {
                 return null
             }
         }
-
-        private fun getSharedPrefs(context: Context): SharedPreferences {
-            return context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)
-        }
     }
 }
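
A note on the Gradle change: `libs.androidx.datastore.preferences` is a version-catalog accessor, and the catalog file is not touched by this diff. If `gradle/libs.versions.toml` does not already declare that alias, it needs to be added; for reference, the direct-coordinate equivalent in the Kotlin DSL would look like the following (the version number is illustrative, not taken from the patch):

```kotlin
dependencies {
    // Jetpack DataStore, Preferences flavor -- the artifact the catalog
    // alias libs.androidx.datastore.preferences is expected to resolve to.
    implementation("androidx.datastore:datastore-preferences:1.1.1")
}
```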
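
The Kotlin changes all follow one pattern: a `preferencesDataStore` delegate on `Context`, a one-shot snapshot read via `data.first()`, and an atomic write via `edit`, with a version key so stale detection results are re-detected rather than trusted. A minimal, self-contained sketch of that versioned-cache pattern, assuming only the DataStore Preferences APIs used in the diff (names like `tierCache`, `cachedTier`, and `detect` are hypothetical stand-ins, not part of the patch):

```kotlin
import android.content.Context
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.first

// Hypothetical demo store; the delegate must live at top level.
private val Context.tierCache by preferencesDataStore(name = "tier_cache_demo")
private val KEY_VERSION = intPreferencesKey("version")
private val KEY_TIER = intPreferencesKey("tier")
private const val CURRENT_VERSION = 1

/**
 * Returns the cached tier if it was written by the current detection
 * version; otherwise runs [detect], persists the result, and returns it.
 */
suspend fun Context.cachedTier(detect: () -> Int): Int {
    val prefs = tierCache.data.first() // one-shot snapshot read
    val version = prefs[KEY_VERSION] ?: -1
    val tier = prefs[KEY_TIER] ?: -1
    if (version == CURRENT_VERSION && tier >= 0) return tier

    // Stale or missing: re-detect and persist both values atomically.
    val fresh = detect()
    tierCache.edit {
        it[KEY_TIER] = fresh
        it[KEY_VERSION] = CURRENT_VERSION
    }
    return fresh
}
```

One trade-off worth flagging: the loader's public entry points stay synchronous, so the diff bridges into this suspend API with `runBlocking`. That blocks the calling thread until the read (and, on first run, the detection and write) completes, so a caller on the main thread pays that cost; exposing a suspend entry point instead would avoid it.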