diff --git a/examples/llama.android/app/src/main/java/com/example/llama/legacy/LegacyViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/legacy/LegacyViewModel.kt
index 82562f29ca..cf17a71aec 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/legacy/LegacyViewModel.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/legacy/LegacyViewModel.kt
@@ -1,6 +1,8 @@
 package com.example.llama.legacy
 
+import android.content.Context
 import android.llama.cpp.LLamaAndroid
+import android.llama.cpp.LLamaLibraryLoader
 import android.util.Log
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
@@ -10,15 +12,16 @@ import androidx.lifecycle.viewModelScope
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.launch
 
-class LegacyViewModel(
-    private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()
-): ViewModel() {
+class LegacyViewModel(context: Context): ViewModel() {
     companion object {
+        private val tag = LegacyViewModel::class.java.simpleName
+
         @JvmStatic
         private val NanosPerSecond = 1_000_000_000.0
     }
 
-    private val tag: String? = this::class.simpleName
+    val llamaAndroid: LLamaAndroid = LLamaLibraryLoader.createInstance(context)
+        ?: throw InstantiationException("Cannot instantiate LlamaAndroid!")
 
     var messages by mutableStateOf(listOf("Initializing..."))
         private set
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
index 4bada497e9..c0f22fff30 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/di/AppModule.kt
@@ -2,7 +2,7 @@ package com.example.llama.revamp.di
 
 import android.content.Context
 import android.llama.cpp.InferenceEngine
-import android.llama.cpp.LLamaAndroid
+import android.llama.cpp.LLamaLibraryLoader
 import com.example.llama.revamp.data.local.AppDatabase
 import com.example.llama.revamp.data.remote.HuggingFaceApiService
 import com.example.llama.revamp.data.remote.HuggingFaceRemoteDataSource
@@ -58,10 +58,15 @@ internal abstract class AppModule {
     ): HuggingFaceRemoteDataSource
 
     companion object {
+        private const val USE_REAL_ENGINE = true
+
         @Provides
-        fun provideInferenceEngine(): InferenceEngine {
-            val useRealEngine = true
-            return if (useRealEngine) LLamaAndroid.instance() else StubInferenceEngine()
+        fun provideInferenceEngine(@ApplicationContext context: Context): InferenceEngine {
+            return if (USE_REAL_ENGINE) {
+                LLamaLibraryLoader.createInstance(context) ?: throw InstantiationException("Cannot instantiate LlamaAndroid!")
+            } else {
+                StubInferenceEngine()
+            }
         }
 
         @Provides
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
index df34127212..e280ae2f81 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/revamp/engine/StubInferenceEngine.kt
@@ -16,15 +16,11 @@ import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
 import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
-import org.jetbrains.annotations.TestOnly
-import org.jetbrains.annotations.VisibleForTesting
 import javax.inject.Singleton
 
 /**
  * A stub [InferenceEngine] for agile development & testing
  */
-@VisibleForTesting
-@TestOnly
 @Singleton
 class StubInferenceEngine : InferenceEngine {
     companion object {
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index a3e26644d7..06c30c987e 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -1,7 +1,6 @@
 package android.llama.cpp
 
 import android.llama.cpp.InferenceEngine.State
-import android.llama.cpp.LLamaAndroid.Companion.instance
 import android.util.Log
 import kotlinx.coroutines.CancellationException
 import kotlinx.coroutines.CoroutineScope
@@ -37,7 +36,35 @@ import java.io.File
  *
  * @see llama-android.cpp for the native implementation details
  */
-class LLamaAndroid private constructor() : InferenceEngine {
+class LLamaAndroid private constructor(private val tier: LLamaTier) : InferenceEngine {
+
+    companion object {
+        private val TAG = LLamaAndroid::class.java.simpleName
+
+        private var initialized = false
+
+        /**
+         * Create LLamaAndroid instance with specific tier
+         */
+        internal fun createWithTier(tier: LLamaTier): LLamaAndroid? {
+            if (initialized) {
+                Log.w(TAG, "LLamaAndroid already initialized")
+                return null
+            }
+
+            try {
+                Log.i(TAG, "Instantiating LLamaAndroid w/ ${tier.libraryName}")
+                val instance = LLamaAndroid(tier)
+                initialized = true
+                return instance
+
+            } catch (e: UnsatisfiedLinkError) {
+                Log.e(TAG, "Failed to load ${tier.libraryName}", e)
+                return null
+            }
+        }
+    }
+
     /**
      * JNI methods
      * @see llama-android.cpp
      */
@@ -74,13 +101,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
             check(_state.value is State.Uninitialized) {
                 "Cannot load native library in ${_state.value.javaClass.simpleName}!"
             }
-            _state.value = State.Initializing
-            Log.i(TAG, "Loading native library $LIB_LLAMA_ANDROID")
-            System.loadLibrary(LIB_LLAMA_ANDROID)
+            Log.i(TAG, "Loading native library for $tier")
+
+            System.loadLibrary(tier.libraryName)
 
             init()
             _state.value = State.Initialized
             Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+
         } catch (e: Exception) {
             Log.e(TAG, "Failed to load native library", e)
             throw e
@@ -235,15 +263,4 @@ class LLamaAndroid private constructor() : InferenceEngine {
             else -> { unload(); shutdown() }
         }
     }
-
-    companion object {
-        private val TAG = LLamaAndroid::class.simpleName
-
-        // TODO-han.yin: replace with dynamic loader
-        private const val LIB_LLAMA_ANDROID = "llama_android_t3"
-
-        // Enforce only one instance of Llm.
-        private val _instance: LLamaAndroid = LLamaAndroid()
-        fun instance(): LLamaAndroid = _instance
-    }
 }