lib: replace the naive & plain SharedPreferences with DataStore implementation

This commit is contained in:
Han Yin 2025-06-26 20:38:52 -07:00
parent 130cba9aa6
commit 57c3a9dda7
2 changed files with 54 additions and 44 deletions

View File

@ -59,6 +59,7 @@ android {
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)

View File

@ -1,9 +1,14 @@
package android.llama.cpp
import android.content.Context
import android.content.SharedPreferences
import android.util.Log
import androidx.core.content.edit
import androidx.datastore.core.DataStore
import androidx.datastore.preferences.core.Preferences
import androidx.datastore.preferences.core.edit
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.datastore.preferences.preferencesDataStore
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.runBlocking
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
@ -26,10 +31,16 @@ class InferenceEngineLoader private constructor() {
companion object {
private val TAG = InferenceEngineLoader::class.simpleName
private const val DETECTION_VERSION = 1
private const val PREFS_NAME = "llama_cpu_detection"
private const val KEY_DETECTED_TIER = "detected_tier"
private const val KEY_DETECTION_VERSION = "detection_version"
// CPU feature detection preferences
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
private val Context.llamaTierDataStore: DataStore<Preferences>
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
private val DETECTION_VERSION = intPreferencesKey("detection_version")
private val DETECTED_TIER = intPreferencesKey("detected_tier")
// Constants
private const val DATASTORE_VERSION = 1
@JvmStatic
private external fun getOptimalTier(): Int
@ -50,11 +61,12 @@ class InferenceEngineLoader private constructor() {
// Return cached instance if available
_cachedInstance?.let { return it }
return runBlocking {
try {
// Obtain the optimal tier from cache if available
val tier = getOrDetectOptimalTier(context) ?: run {
Log.e(TAG, "Failed to determine optimal tier")
return null
return@runBlocking null
}
_detectedTier = tier
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
@ -62,16 +74,17 @@ class InferenceEngineLoader private constructor() {
// Create and cache the inference engine instance
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
return null
return@runBlocking null
}
_cachedInstance = instance
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
return instance
instance
} catch (e: Exception) {
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
return null
null
}
}
}
@ -79,7 +92,7 @@ class InferenceEngineLoader private constructor() {
* Clear cached detection results (for testing/debugging)
*/
fun clearCache(context: Context) {
getSharedPrefs(context).edit { clear() }
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
_cachedInstance = null
_detectedTier = null
Log.i(TAG, "Cleared detection results and cached instance")
@ -88,13 +101,13 @@ class InferenceEngineLoader private constructor() {
/**
* Get optimal tier from cache or detect it fresh
*/
private fun getOrDetectOptimalTier(context: Context): LLamaTier? {
val prefs = getSharedPrefs(context)
private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
val preferences = context.llamaTierDataStore.data.first()
// Check if we have a cached result with the current detection version
val cachedVersion = prefs.getInt(KEY_DETECTION_VERSION, -1)
val cachedTierValue = prefs.getInt(KEY_DETECTED_TIER, -1)
if (cachedVersion == DETECTION_VERSION && cachedTierValue >= 0) {
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
val cachedTier = LLamaTier.fromRawValue(cachedTierValue)
if (cachedTier != null) {
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
@ -110,7 +123,7 @@ class InferenceEngineLoader private constructor() {
/**
* Detect optimal tier and save to cache
*/
private fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
try {
// Load CPU detection library
System.loadLibrary("llama_cpu_detector")
@ -136,9 +149,9 @@ class InferenceEngineLoader private constructor() {
}
// Cache the result
getSharedPrefs(context).edit {
putInt(KEY_DETECTED_TIER, finalTier.rawValue)
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
context.llamaTierDataStore.edit {
it[DETECTED_TIER] = finalTier.rawValue
it[DETECTION_VERSION] = DATASTORE_VERSION
}
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
@ -149,9 +162,9 @@ class InferenceEngineLoader private constructor() {
// Fallback to T0 and cache it
val fallbackTier = LLamaTier.T0
getSharedPrefs(context).edit {
putInt(KEY_DETECTED_TIER, fallbackTier.rawValue)
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
context.llamaTierDataStore.edit {
it[DETECTED_TIER] = fallbackTier.rawValue
it[DETECTION_VERSION] = DATASTORE_VERSION
}
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
@ -162,9 +175,5 @@ class InferenceEngineLoader private constructor() {
return null
}
}
private fun getSharedPrefs(context: Context): SharedPreferences {
return context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)
}
}
}