lib: replace the factory pattern for deprecated tiered lib loading with single instance pattern

This commit is contained in:
Han Yin 2025-09-18 16:55:01 -07:00
parent 63e5bd0771
commit 6dfdc2c105
5 changed files with 35 additions and 69 deletions

View File

@ -76,8 +76,7 @@ internal abstract class AppModule {
return if (USE_STUB_ENGINE) { return if (USE_STUB_ENGINE) {
StubInferenceEngine() StubInferenceEngine()
} else { } else {
KleidiLlama.createInferenceEngine(context) KleidiLlama.getInferenceEngine(context)
?: throw InstantiationException("Cannot instantiate InferenceEngine!")
} }
} }

View File

@ -1,7 +1,7 @@
package android.llama.cpp package android.llama.cpp
import android.content.Context import android.content.Context
import android.llama.cpp.internal.InferenceEngineFactory import android.llama.cpp.internal.InferenceEngineImpl
import android.llama.cpp.internal.TierDetectionImpl import android.llama.cpp.internal.TierDetectionImpl
/** /**
@ -10,12 +10,12 @@ import android.llama.cpp.internal.TierDetectionImpl
*/ */
object KleidiLlama { object KleidiLlama {
/** /**
* Create an inference engine instance with automatic tier detection. * Get the inference engine single instance.
*/ */
fun createInferenceEngine(context: Context) = InferenceEngineFactory.getInstance(context) fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
/** /**
* Get tier detection information for debugging/settings. * Get tier detection single instance.
*/ */
fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) fun getTierDetection(context: Context): TierDetection = TierDetectionImpl.getInstance(context)
} }

View File

@ -1,45 +0,0 @@
package android.llama.cpp.internal
import android.content.Context
import android.llama.cpp.InferenceEngine
import android.llama.cpp.TierDetection
import android.util.Log
import kotlinx.coroutines.runBlocking
/**
* Internal factory to create [InferenceEngine] and [TierDetection]
*/
internal object InferenceEngineFactory {
private val TAG = InferenceEngineFactory::class.simpleName
private var _cachedInstance: InferenceEngineImpl? = null
/**
* Factory method to get a configured [InferenceEngineImpl] instance.
* Handles tier detection, caching, and library loading automatically.
*/
@Synchronized
fun getInstance(context: Context): InferenceEngine? {
// Return cached instance if available
_cachedInstance?.let { return it }
return runBlocking {
try {
// Create and cache the inference engine instance
InferenceEngineImpl.create(context).also {
_cachedInstance = it
Log.i(TAG, "Successfully instantiated Inference Engine")
}
} catch (e: Exception) {
Log.e(TAG, "Error instantiating Inference Engine", e)
null
}
}
}
fun clearCache() {
_cachedInstance = null
Log.i(TAG, "Cleared cached instance of InferenceEngine")
}
}

View File

@ -29,7 +29,7 @@ import java.io.IOException
* with the underlying C++ native code. * with the underlying C++ native code.
* *
* The typical usage flow is: * The typical usage flow is:
* 1. Get instance via [instance] * 1. Get instance via [getInstance]
* 2. Load a model with [loadModel] * 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt] * 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams * 4. Generate responses as token streams
@ -47,30 +47,29 @@ internal class InferenceEngineImpl private constructor(
companion object { companion object {
private val TAG = InferenceEngineImpl::class.java.simpleName private val TAG = InferenceEngineImpl::class.java.simpleName
private var initialized = false @Volatile
private var instance: InferenceEngine? = null
/** /**
* Create [InferenceEngineImpl] instance at runtime * Create or obtain [InferenceEngineImpl]'s single instance.
* *
* @param Context for obtaining native library directory * @param Context for obtaining native library directory
* @throws IllegalArgumentException if native library path is invalid * @throws IllegalArgumentException if native library path is invalid
* @throws UnsatisfiedLinkError if library failed to load * @throws UnsatisfiedLinkError if library failed to load
*/ */
internal fun create(context: Context): InferenceEngineImpl { internal fun getInstance(context: Context) =
assert(!initialized) { "Inference Engine has already been initialized!" } instance ?: synchronized(this) {
val nativeLibDir = context.applicationInfo.nativeLibraryDir
require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
val nativeLibDir = context.applicationInfo.nativeLibraryDir try {
require(nativeLibDir.isNotBlank()) { "Expected native library" } Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
InferenceEngineImpl(nativeLibDir).also { instance = it }
return try { } catch (e: UnsatisfiedLinkError) {
Log.i(TAG, "Instantiating InferenceEngineImpl,,,") Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
InferenceEngineImpl(nativeLibDir).also { initialized = true } throw e
}
} catch (e: UnsatisfiedLinkError) {
Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
throw e
} }
}
} }
/** /**

View File

@ -15,7 +15,7 @@ import kotlinx.coroutines.runBlocking
/** /**
* Internal [LLamaTier] detection implementation * Internal [LLamaTier] detection implementation
*/ */
internal class TierDetectionImpl( internal class TierDetectionImpl private constructor(
private val context: Context private val context: Context
): TierDetection { ): TierDetection {
@ -30,6 +30,19 @@ internal class TierDetectionImpl(
private val DETECTION_VERSION = intPreferencesKey("detection_version") private val DETECTION_VERSION = intPreferencesKey("detection_version")
private val DETECTED_TIER = intPreferencesKey("detected_tier") private val DETECTED_TIER = intPreferencesKey("detected_tier")
@Volatile
private var instance: TierDetection? = null
/**
* Create or obtain [TierDetectionImpl]'s single instance.
*
* @param Context for obtaining the data store
*/
internal fun getInstance(context: Context) =
instance ?: synchronized(this) {
instance ?: TierDetectionImpl(context).also { instance = it }
}
} }
private external fun getOptimalTier(): Int private external fun getOptimalTier(): Int