lib: replace the naive & plain SharedPreferences with DataStore implementation
This commit is contained in:
parent
130cba9aa6
commit
57c3a9dda7
|
|
@ -59,6 +59,7 @@ android {
|
|||
|
||||
dependencies {
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.datastore.preferences)
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
package android.llama.cpp
|
||||
|
||||
import android.content.Context
|
||||
import android.content.SharedPreferences
|
||||
import android.util.Log
|
||||
import androidx.core.content.edit
|
||||
import androidx.datastore.core.DataStore
|
||||
import androidx.datastore.preferences.core.Preferences
|
||||
import androidx.datastore.preferences.core.edit
|
||||
import androidx.datastore.preferences.core.intPreferencesKey
|
||||
import androidx.datastore.preferences.preferencesDataStore
|
||||
import kotlinx.coroutines.flow.first
|
||||
import kotlinx.coroutines.runBlocking
|
||||
|
||||
enum class LLamaTier(val rawValue: Int, val libraryName: String, val description: String) {
|
||||
T0(0, "llama_android_t0", "ARMv8-a baseline with SIMD"),
|
||||
|
|
@ -26,10 +31,16 @@ class InferenceEngineLoader private constructor() {
|
|||
companion object {
|
||||
private val TAG = InferenceEngineLoader::class.simpleName
|
||||
|
||||
private const val DETECTION_VERSION = 1
|
||||
private const val PREFS_NAME = "llama_cpu_detection"
|
||||
private const val KEY_DETECTED_TIER = "detected_tier"
|
||||
private const val KEY_DETECTION_VERSION = "detection_version"
|
||||
// CPU feature detection preferences
|
||||
private const val DATASTORE_CPU_DETECTION = "llama_cpu_detection"
|
||||
private val Context.llamaTierDataStore: DataStore<Preferences>
|
||||
by preferencesDataStore(name = DATASTORE_CPU_DETECTION)
|
||||
|
||||
private val DETECTION_VERSION = intPreferencesKey("detection_version")
|
||||
private val DETECTED_TIER = intPreferencesKey("detected_tier")
|
||||
|
||||
// Constants
|
||||
private const val DATASTORE_VERSION = 1
|
||||
|
||||
@JvmStatic
|
||||
private external fun getOptimalTier(): Int
|
||||
|
|
@ -50,28 +61,30 @@ class InferenceEngineLoader private constructor() {
|
|||
// Return cached instance if available
|
||||
_cachedInstance?.let { return it }
|
||||
|
||||
try {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = getOrDetectOptimalTier(context) ?: run {
|
||||
Log.e(TAG, "Failed to determine optimal tier")
|
||||
return null
|
||||
return runBlocking {
|
||||
try {
|
||||
// Obtain the optimal tier from cache if available
|
||||
val tier = getOrDetectOptimalTier(context) ?: run {
|
||||
Log.e(TAG, "Failed to determine optimal tier")
|
||||
return@runBlocking null
|
||||
}
|
||||
_detectedTier = tier
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
|
||||
// Create and cache the inference engine instance
|
||||
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
|
||||
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
|
||||
return@runBlocking null
|
||||
}
|
||||
_cachedInstance = instance
|
||||
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
|
||||
|
||||
instance
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
|
||||
null
|
||||
}
|
||||
_detectedTier = tier
|
||||
Log.i(TAG, "Using tier: ${tier.name} (${tier.description})")
|
||||
|
||||
// Create and cache the inference engine instance
|
||||
val instance = InferenceEngineImpl.createWithTier(tier) ?: run {
|
||||
Log.e(TAG, "Failed to instantiate InferenceEngineImpl")
|
||||
return null
|
||||
}
|
||||
_cachedInstance = instance
|
||||
Log.i(TAG, "Successfully created InferenceEngineImpl instance with ${tier.name}")
|
||||
|
||||
return instance
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error creating InferenceEngineImpl instance", e)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -79,7 +92,7 @@ class InferenceEngineLoader private constructor() {
|
|||
* Clear cached detection results (for testing/debugging)
|
||||
*/
|
||||
fun clearCache(context: Context) {
|
||||
getSharedPrefs(context).edit { clear() }
|
||||
runBlocking { context.llamaTierDataStore.edit { it.clear() } }
|
||||
_cachedInstance = null
|
||||
_detectedTier = null
|
||||
Log.i(TAG, "Cleared detection results and cached instance")
|
||||
|
|
@ -88,13 +101,13 @@ class InferenceEngineLoader private constructor() {
|
|||
/**
|
||||
* Get optimal tier from cache or detect it fresh
|
||||
*/
|
||||
private fun getOrDetectOptimalTier(context: Context): LLamaTier? {
|
||||
val prefs = getSharedPrefs(context)
|
||||
private suspend fun getOrDetectOptimalTier(context: Context): LLamaTier? {
|
||||
val preferences = context.llamaTierDataStore.data.first()
|
||||
|
||||
// Check if we have a cached result with the current detection version
|
||||
val cachedVersion = prefs.getInt(KEY_DETECTION_VERSION, -1)
|
||||
val cachedTierValue = prefs.getInt(KEY_DETECTED_TIER, -1)
|
||||
if (cachedVersion == DETECTION_VERSION && cachedTierValue >= 0) {
|
||||
val cachedVersion = preferences[DETECTION_VERSION] ?: -1
|
||||
val cachedTierValue = preferences[DETECTED_TIER] ?: -1
|
||||
if (cachedVersion == DATASTORE_VERSION && cachedTierValue >= 0) {
|
||||
val cachedTier = LLamaTier.fromRawValue(cachedTierValue)
|
||||
if (cachedTier != null) {
|
||||
Log.i(TAG, "Using cached tier detection: ${cachedTier.name}")
|
||||
|
|
@ -110,7 +123,7 @@ class InferenceEngineLoader private constructor() {
|
|||
/**
|
||||
* Detect optimal tier and save to cache
|
||||
*/
|
||||
private fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
|
||||
private suspend fun detectAndCacheOptimalTier(context: Context): LLamaTier? {
|
||||
try {
|
||||
// Load CPU detection library
|
||||
System.loadLibrary("llama_cpu_detector")
|
||||
|
|
@ -136,9 +149,9 @@ class InferenceEngineLoader private constructor() {
|
|||
}
|
||||
|
||||
// Cache the result
|
||||
getSharedPrefs(context).edit {
|
||||
putInt(KEY_DETECTED_TIER, finalTier.rawValue)
|
||||
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = finalTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Detected and cached optimal tier: ${finalTier.name}")
|
||||
|
|
@ -149,9 +162,9 @@ class InferenceEngineLoader private constructor() {
|
|||
|
||||
// Fallback to T0 and cache it
|
||||
val fallbackTier = LLamaTier.T0
|
||||
getSharedPrefs(context).edit {
|
||||
putInt(KEY_DETECTED_TIER, fallbackTier.rawValue)
|
||||
putInt(KEY_DETECTION_VERSION, DETECTION_VERSION)
|
||||
context.llamaTierDataStore.edit {
|
||||
it[DETECTED_TIER] = fallbackTier.rawValue
|
||||
it[DETECTION_VERSION] = DATASTORE_VERSION
|
||||
}
|
||||
|
||||
Log.i(TAG, "Using fallback tier: ${fallbackTier.name}")
|
||||
|
|
@ -162,9 +175,5 @@ class InferenceEngineLoader private constructor() {
|
|||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private fun getSharedPrefs(context: Context): SharedPreferences {
|
||||
return context.getSharedPreferences(PREFS_NAME, Context.MODE_PRIVATE)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue