From 6dfdc2c105245dcce8b28a699a22a5af4a95099c Mon Sep 17 00:00:00 2001 From: Han Yin Date: Thu, 18 Sep 2025 16:55:01 -0700 Subject: [PATCH] lib: replace the factory pattern for deprecated tiered lib loading with single instance pattern --- .../java/com/example/llama/di/AppModule.kt | 3 +- .../java/android/llama/cpp/KleidiLlama.kt | 10 ++--- .../cpp/internal/InferenceEngineFactory.kt | 45 ------------------- .../llama/cpp/internal/InferenceEngineImpl.kt | 31 +++++++------ .../llama/cpp/internal/TierDetectionImpl.kt | 15 ++++++- 5 files changed, 35 insertions(+), 69 deletions(-) delete mode 100644 examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt diff --git a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt index 272874ea3d..a840973857 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/di/AppModule.kt @@ -76,8 +76,7 @@ internal abstract class AppModule { return if (USE_STUB_ENGINE) { StubInferenceEngine() } else { - KleidiLlama.createInferenceEngine(context) - ?: throw InstantiationException("Cannot instantiate InferenceEngine!") + KleidiLlama.getInferenceEngine(context) } } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt index 4643753b35..ee212b86d4 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/KleidiLlama.kt @@ -1,7 +1,7 @@ package android.llama.cpp import android.content.Context -import android.llama.cpp.internal.InferenceEngineFactory +import android.llama.cpp.internal.InferenceEngineImpl import android.llama.cpp.internal.TierDetectionImpl /** @@ -10,12 +10,12 @@ import 
android.llama.cpp.internal.TierDetectionImpl */ object KleidiLlama { /** - * Create an inference engine instance with automatic tier detection. + * Get the inference engine single instance. */ - fun createInferenceEngine(context: Context) = InferenceEngineFactory.getInstance(context) + fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context) /** - * Get tier detection information for debugging/settings. + * Get tier detection single instance. */ - fun getTierDetection(context: Context): TierDetection = TierDetectionImpl(context) + fun getTierDetection(context: Context): TierDetection = TierDetectionImpl.getInstance(context) } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt deleted file mode 100644 index 811f13b587..0000000000 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineFactory.kt +++ /dev/null @@ -1,45 +0,0 @@ -package android.llama.cpp.internal - -import android.content.Context -import android.llama.cpp.InferenceEngine -import android.llama.cpp.TierDetection -import android.util.Log -import kotlinx.coroutines.runBlocking - -/** - * Internal factory to create [InferenceEngine] and [TierDetection] - */ -internal object InferenceEngineFactory { - private val TAG = InferenceEngineFactory::class.simpleName - - private var _cachedInstance: InferenceEngineImpl? = null - - /** - * Factory method to get a configured [InferenceEngineImpl] instance. - * Handles tier detection, caching, and library loading automatically. - */ - @Synchronized - fun getInstance(context: Context): InferenceEngine? 
{ - // Return cached instance if available - _cachedInstance?.let { return it } - - return runBlocking { - try { - // Create and cache the inference engine instance - InferenceEngineImpl.create(context).also { - _cachedInstance = it - Log.i(TAG, "Successfully instantiated Inference Engine") - } - - } catch (e: Exception) { - Log.e(TAG, "Error instantiating Inference Engine", e) - null - } - } - } - - fun clearCache() { - _cachedInstance = null - Log.i(TAG, "Cleared cached instance of InferenceEngine") - } -} diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt index 1c42406c4a..7ffd8e9503 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/InferenceEngineImpl.kt @@ -29,7 +29,7 @@ import java.io.IOException * with the underlying C++ native code. * * The typical usage flow is: - * 1. Get instance via [instance] + * 1. Get instance via [getInstance] * 2. Load a model with [loadModel] * 3. Send prompts with [sendUserPrompt] * 4. Generate responses as token streams @@ -47,30 +47,29 @@ internal class InferenceEngineImpl private constructor( companion object { private val TAG = InferenceEngineImpl::class.java.simpleName - private var initialized = false + @Volatile + private var instance: InferenceEngine? = null /** - * Create [InferenceEngineImpl] instance at runtime + * Create or obtain [InferenceEngineImpl]'s single instance. * * @param Context for obtaining native library directory * @throws IllegalArgumentException if native library path is invalid * @throws UnsatisfiedLinkError if library failed to load */ - internal fun create(context: Context): InferenceEngineImpl { - assert(!initialized) { "Inference Engine has already been initialized!" 
} + internal fun getInstance(context: Context) = + instance ?: synchronized(this) { + val nativeLibDir = context.applicationInfo.nativeLibraryDir + require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" } - val nativeLibDir = context.applicationInfo.nativeLibraryDir - require(nativeLibDir.isNotBlank()) { "Expected native library" } - - return try { - Log.i(TAG, "Instantiating InferenceEngineImpl,,,") - InferenceEngineImpl(nativeLibDir).also { initialized = true } - - } catch (e: UnsatisfiedLinkError) { - Log.e(TAG, "Failed to load native library from $nativeLibDir", e) - throw e + instance ?: try { + Log.i(TAG, "Instantiating InferenceEngineImpl...") + InferenceEngineImpl(nativeLibDir).also { instance = it } + } catch (e: UnsatisfiedLinkError) { + Log.e(TAG, "Failed to load native library from $nativeLibDir", e) + throw e + } } - } } /** diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt index 2683395a62..bb4a873b99 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/internal/TierDetectionImpl.kt @@ -15,7 +15,7 @@ import kotlinx.coroutines.runBlocking /** * Internal [LLamaTier] detection implementation */ -internal class TierDetectionImpl( +internal class TierDetectionImpl private constructor( private val context: Context ): TierDetection { @@ -30,6 +30,19 @@ internal class TierDetectionImpl( private val DETECTION_VERSION = intPreferencesKey("detection_version") private val DETECTED_TIER = intPreferencesKey("detected_tier") + + @Volatile + private var instance: TierDetection? = null + + /** + * Create or obtain [TierDetectionImpl]'s single instance. 
+ * + * @param context the [Context] for obtaining the data store + */ + internal fun getInstance(context: Context) = + instance ?: synchronized(this) { + instance ?: TierDetectionImpl(context).also { instance = it } + } } private external fun getOptimalTier(): Int