lib: replace the naive & plain SharedPreferences with DataStore implementation

This commit is contained in:
Han Yin 2025-06-26 20:38:52 -07:00
parent 130cba9aa6
commit 57c3a9dda7
2 changed files with 54 additions and 44 deletions

View File

@ -59,6 +59,7 @@ android {
dependencies { dependencies {
implementation(libs.androidx.core.ktx) implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit) testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit) androidTestImplementation(libs.androidx.junit)

View File

@ -1,9 +1,14 @@
package android.llama.cpp package android.llama.cpp
import android.content.Context import android.content.Context
import android.content.SharedPreferences
import android.util.Log import android.util.Log
import androidx.core.content.edit import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) { enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"), T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
@ -26,10 +31,16 @@ class InferenceEngineLoader private constructor() {
companion object { companion object {
private val TAG = InferenceEngineLoader::class.simpleName private val TAG = InferenceEngineLoader::class.simpleName
private const val DETECTION_VERSION = 1 // CPU feature detection preferences
private const val PREFS_NAME = "llama_cpu_detection" private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
private const val KEY_DETECTED_TIER = "detected_tier" private val Context.llamaTierDataStore: DataStore<Preferences>
private const val KEY_DETECTION_VERSION = "detection_version" by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
private val DETECTION_VERSION = intPreferencesKey("detection_version")
private val DETECTED_TIER = intPreferencesKey("detected_tier")
// Constants
private const val DATASTORE_VERSION = 1
@JvmStatic @JvmStatic
private external fun getOptimalTier(): Int private external fun getOptimalTier(): Int
@ -50,11 +61,12 @@ class InferenceEngineLoader private constructor() {
// Return cached instance if available // Return cached instance if available
_cachedInstance?.let { return it } _cachedInstance?.let { return it }
return runBlocking {
try { try {
// Obtain the optimal tier from cache if available // Obtain the optimal tier from cache if available
val tier = getOrDetectOptimalTier(context) ?: run { val tier = getOrDetectOptimalTier(context) ?: run {
Log.e(TAG, "Failed to determine optimal tier") Log.e(TAG, "Failed to determine optimal tier")
return null return@runBlocking null
} }
_detectedTier = tier _detectedTier = tier
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})") Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
@ -62,16 +74,17 @@ class InferenceEngineLoader private constructor() {
// Create and cache the inference engine instance // Create and cache the inference engine instance
val instance = InferenceEngineImpl.createWithTier(tier) ?: run { val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
Log.e(TAG, "Failed to instantiate InferenceEngineImpl") Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
return null return@runBlocking null
} }
_cachedInstance = instance _cachedInstance = instance
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}") Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
return instance instance
} catch (e: Exception) { } catch (e: Exception) {
Log.e(TAG, "Error creating InferenceEngineImpl instance", e) Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
return null null
}
} }
} }
@ -79,7 +92,7 @@ class InferenceEngineLoader private constructor() {
* Clear cached detection results (for testing/debugging) * Clear cached detection results (for testing/debugging)
*/ */
fun clearCache(context: Context) { fun clearCache(context: Context) {
getSharedPrefs(context).edit { clear() } runBlocking { context.llamaTierDataStore.edit { it.clear() } }
_cachedInstance = null _cachedInstance = null
_detectedTier = null _detectedTier = null
Log.i(TAG, "Cleared detection results and cached instance") Log.i(TAG, "Cleared detection results and cached instance")
@ -88,13 +101,13 @@ class InferenceEngineLoader private constructor() {
/** /**
* Get optimal tier from cache or detect it fresh * Get optimal tier from cache or detect it fresh
*/ */
private fun getOrDetectOptimalTier(context: Context): LLamaTier? { private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
val prefs = getSharedPrefs(context) val preferences = context.llamaTierDataStore.data.first()
// Check if we have a cached result with the current detection version // Check if we have a cached result with the current detection version
val cachedVersion = prefs.getInt(KEY_DETECTION_VERSION, -1) val cachedVersion = preferences[DETECTION_VERSION] ?: -1
val cachedTierValue = prefs.getInt(KEY_DETECTED_TIER, -1) val cachedTierValue = preferences[DETECTED_TIER] ?: -1
if (cachedVersion == DETECTION_VERSION && cachedTierValue >= 0) { if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
val cachedTier = LLamaTier.fromRawValue(cachedTierValue) val cachedTier = LLamaTier.fromRawValue(cachedTierValue)
if (cachedTier != null) { if (cachedTier != null) {
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}") Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
@ -110,7 +123,7 @@ class InferenceEngineLoader private constructor() {
/** /**
* Detect optimal tier and save to cache * Detect optimal tier and save to cache
*/ */
private fun detectAndCacheOptimalTier(context: Context): LLamaTier? { private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
try { try {
// Load CPU detection library // Load CPU detection library
System.loadLibrary("llama_cpu_detector") System.loadLibrary("llama_cpu_detector")
@ -136,9 +149,9 @@ class InferenceEngineLoader private constructor() {
} }
// Cache the result // Cache the result
getSharedPrefs(context).edit { context.llamaTierDataStore.edit {
putInt(KEY_DETECTED_TIER, finalTier.rawValue) it[DETECTED_TIER] = finalTier.rawValue
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION) it[DETECTION_VERSION] = DATASTORE_VERSION
} }
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}") Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
@ -149,9 +162,9 @@ class InferenceEngineLoader private constructor() {
// Fallback to T0 and cache it // Fallback to T0 and cache it
val fallbackTier = LLamaTier.T0 val fallbackTier = LLamaTier.T0
getSharedPrefs(context).edit { context.llamaTierDataStore.edit {
putInt(KEY_DETECTED_TIER, fallbackTier.rawValue) it[DETECTED_TIER] = fallbackTier.rawValue
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION) it[DETECTION_VERSION] = DATASTORE_VERSION
} }
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}") Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
@ -162,9 +175,5 @@ class InferenceEngineLoader private constructor() {
return null return null
} }
} }
private fun getSharedPrefs(context: Context): SharedPreferences {
return context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)
}
} }
} }