LLama: refactor loadModel by splitting the system prompt setting into a separate method
This commit is contained in:
parent 9f77155535
commit 65d4a57a8b
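At a glance, the caller-side change looks roughly like this. It is only a sketch: the engine parameter, the file path, and the prompt strings are placeholders (imports are omitted because package names are not shown in this diff); the method names come from the InferenceEngine interface changed below.

// Before this commit: one call loaded the model and, optionally, primed the system prompt.
suspend fun primeEngineBefore(engine: InferenceEngine) {
    engine.loadModel("/models/example.gguf", "You are a helpful assistant.")
}

// After this commit: loading and system prompt processing are two explicit steps.
suspend fun primeEngineAfter(engine: InferenceEngine) {
    engine.loadModel("/models/example.gguf")
    engine.setSystemPrompt("You are a helpful assistant.") // must directly follow loadModel (see LLamaAndroid below)
}
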
@@ -18,47 +18,4 @@ data class ModelInfo(
 ) {
     val formattedSize: String
         get() = formatSize(sizeInBytes)
-
-    companion object {
-        /**
-         * Creates a list of sample models for development and testing.
-         */
-        fun getSampleModels(): List<ModelInfo> {
-            return listOf(
-                ModelInfo(
-                    id = "mistral-7b",
-                    name = "Mistral 7B",
-                    path = "/storage/models/mistral-7b-q4_0.gguf",
-                    sizeInBytes = 4_000_000_000,
-                    parameters = "7B",
-                    quantization = "Q4_K_M",
-                    type = "Mistral",
-                    contextLength = 8192,
-                    lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago
-                ),
-                ModelInfo(
-                    id = "llama2-13b",
-                    name = "Llama 2 13B",
-                    path = "/storage/models/llama2-13b-q5_k_m.gguf",
-                    sizeInBytes = 8_500_000_000,
-                    parameters = "13B",
-                    quantization = "Q5_K_M",
-                    type = "Llama",
-                    contextLength = 4096,
-                    lastUsed = System.currentTimeMillis() - 259200000 // 3 days ago
-                ),
-                ModelInfo(
-                    id = "phi-2",
-                    name = "Phi-2",
-                    path = "/storage/models/phi-2.gguf",
-                    sizeInBytes = 2_800_000_000,
-                    parameters = "2.7B",
-                    quantization = "Q4_0",
-                    type = "Phi",
-                    contextLength = 2048,
-                    lastUsed = null
-                )
-            )
-        }
-    }
 }

@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): Boolean
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?

     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }

 interface BenchmarkService : InferenceService {

@@ -80,6 +80,17 @@ interface ConversationService : InferenceService {
     fun createTokenMetrics(): TokenMetrics
 }

+/**
+ * Metrics for model loading and system prompt processing
+ */
+data class ModelLoadingMetrics(
+    val modelLoadingTimeMs: Long,
+    val systemPromptProcessingTimeMs: Long? = null
+) {
+    val totalTimeMs: Long
+        get() = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
+}
+
 /**
  * Represents an update during text generation
  */

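A rough usage sketch for the new ModelLoadingMetrics return type; the function name, service parameter, log tag, and prompt string are placeholders, while the fields and totalTimeMs come from the class added above.

suspend fun logLoadTimes(service: ModelLoadingService) {
    // null now signals a failed load, replacing the old Boolean return
    val metrics = service.loadModelForConversation("You are a helpful assistant.") ?: run {
        Log.e("LoadTimes", "Model failed to load")
        return
    }
    Log.i("LoadTimes", "Model load: ${metrics.modelLoadingTimeMs} ms")
    Log.i("LoadTimes", "System prompt: ${metrics.systemPromptProcessingTimeMs ?: 0} ms")
    Log.i("LoadTimes", "Total: ${metrics.totalTimeMs} ms") // modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
}
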
@@ -115,9 +126,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()

-    override fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }

     override suspend fun unloadModel() = inferenceEngine.unloadModel()

@@ -129,29 +138,45 @@ internal class InferenceServiceImpl @Inject internal constructor(

     /* ModelLoadingService implementation */

-    override suspend fun loadModelForBenchmark(): Boolean {
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
         return _currentModel.value?.let { model ->
             try {
+                val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
-                true
+                val modelLoadEndTs = System.currentTimeMillis()
+                ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }

-    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
-                inferenceEngine.loadModel(model.path, systemPrompt)
-                true
+                val modelLoadStartTs = System.currentTimeMillis()
+                inferenceEngine.loadModel(model.path)
+                val modelLoadEndTs = System.currentTimeMillis()
+
+                if (systemPrompt.isNullOrBlank()) {
+                    ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
+                } else {
+                    val prompt: String = systemPrompt
+                    val systemPromptStartTs = System.currentTimeMillis()
+                    inferenceEngine.setSystemPrompt(prompt)
+                    val systemPromptEndTs = System.currentTimeMillis()
+                    ModelLoadingMetrics(
+                        modelLoadingTimeMs = modelLoadEndTs - modelLoadStartTs,
+                        systemPromptProcessingTimeMs = systemPromptEndTs - systemPromptStartTs
+                    )
+                }
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }


@@ -42,9 +42,9 @@ class StubInferenceEngine : InferenceEngine {
     }

     /**
-     * Loads a model from the given path with an optional system prompt.
+     * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
+    override suspend fun loadModel(pathToModel: String) {
         Log.i(TAG, "loadModel! state: ${_state.value}")

         try {

@@ -53,16 +53,28 @@ class StubInferenceEngine : InferenceEngine {
             // Simulate model loading
             delay(STUB_MODEL_LOADING_TIME)

-            _state.value = State.ModelLoaded
-
-            if (systemPrompt != null) {
-                _state.value = State.ProcessingSystemPrompt
-
-                // Simulate processing system prompt
-                delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
-            }
-
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+        }
+    }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) {
+        try {
+            _state.value = State.ProcessingSystemPrompt
+
+            // Simulate processing system prompt
+            delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
+
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             throw e

@@ -94,10 +106,10 @@ class StubInferenceEngine : InferenceEngine {
                 delay(STUB_TOKEN_GENERATION_TIME)
             }

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // Handle cancellation gracefully
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during generation")

@@ -144,12 +156,12 @@ class StubInferenceEngine : InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady

             return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")

@@ -101,7 +101,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Uninitialized &&
             engineState !is State.LibraryLoaded &&
-            engineState !is State.AwaitingUserPrompt
+            engineState !is State.ModelReady

     // Mode selection callbacks
     val handleBenchmarkSelected = {

@@ -431,7 +431,7 @@ fun ModelLoadingScreen(
                 text = when (engineState) {
                     is State.LoadingModel -> "Loading model..."
                     is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    is State.ModelLoaded -> "Preparing conversation..."
+                    is State.ModelReady -> "Preparing conversation..."
                     else -> "Processing..."
                 },
                 style = MaterialTheme.typography.titleMedium

@@ -13,9 +13,14 @@ interface InferenceEngine {
     val state: StateFlow<State>

     /**
-     * Load a model from the given path with an optional system prompt.
+     * Load a model from the given path.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+    suspend fun loadModel(pathToModel: String)

+    /**
+     * Sends a system prompt to the loaded model
+     */
+    suspend fun setSystemPrompt(systemPrompt: String)
+
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.

@@ -45,11 +50,9 @@ interface InferenceEngine {
         object LibraryLoaded : State()

         object LoadingModel : State()
-        object ModelLoaded : State()
+        object ModelReady : State()

         object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
         object ProcessingUserPrompt : State()
         object Generating : State()

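Since ModelLoaded and AwaitingUserPrompt collapse into a single ModelReady state, code that maps engine states to labels can reduce to something like the following sketch (the function name and label strings are placeholders).

fun statusLabel(state: InferenceEngine.State): String = when (state) {
    is InferenceEngine.State.LoadingModel -> "Loading model..."
    is InferenceEngine.State.ProcessingSystemPrompt -> "Processing system prompt..."
    is InferenceEngine.State.ModelReady -> "Ready"        // previously ModelLoaded or AwaitingUserPrompt
    is InferenceEngine.State.Error -> "Error"
    else -> "Working..."
}
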
@@ -18,10 +18,6 @@ import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 import java.io.File

-@Target(AnnotationTarget.FUNCTION)
-@Retention(AnnotationRetention.SOURCE)
-annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
-
 /**
  * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
  *

@@ -63,6 +59,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

+    private var _readyForSystemPrompt = false
+
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
      */

@@ -85,9 +83,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     }

     /**
-     * Load the LLM, then process the plain text system prompt if provided
+     * Load the LLM
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
+    override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {

@@ -96,6 +94,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
             }

             Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
             _state.value = State.LoadingModel
             load(pathToModel).let { result ->
                 if (result != 0) throw IllegalStateException("Failed to Load model: $result")

@@ -104,23 +103,31 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
             }
             Log.i(TAG, "Model loaded!")
-            _state.value = State.ModelLoaded
-
-            systemPrompt?.let { prompt ->
-                Log.i(TAG, "Sending system prompt...")
-                _state.value = State.ProcessingSystemPrompt
-                processSystemPrompt(prompt).let { result ->
-                    if (result != 0) {
-                        val errorMessage = "Failed to process system prompt: $result"
-                        _state.value = State.Error(errorMessage)
-                        throw IllegalStateException(errorMessage)
-                    }
-                }
-                Log.i(TAG, "System prompt processed! Awaiting user prompt...")
-            } ?: run {
-                Log.w(TAG, "No system prompt to process.")
-            }
-            _state.value = State.AwaitingUserPrompt
+            _readyForSystemPrompt = true
+            _state.value = State.ModelReady
+        }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
+
+            Log.i(TAG, "Sending system prompt...")
+            _readyForSystemPrompt = false
+            _state.value = State.ProcessingSystemPrompt
+            processSystemPrompt(prompt).let { result ->
+                if (result != 0) {
+                    val errorMessage = "Failed to process system prompt: $result"
+                    _state.value = State.Error(errorMessage)
+                    throw IllegalStateException(errorMessage)
+                }
+            }
+            Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+            _state.value = State.ModelReady
         }

     /**

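The new _readyForSystemPrompt flag restricts setSystemPrompt to the window right after a successful loadModel. A sketch of the intended ordering (the function name, llama parameter, and argument values are placeholders):

suspend fun prepareForChat(llama: InferenceEngine, path: String, systemPrompt: String?) {
    llama.loadModel(path)                   // finishes in State.ModelReady with the flag set
    if (!systemPrompt.isNullOrBlank()) {
        llama.setSystemPrompt(systemPrompt) // sending a user prompt or calling bench() first would clear
    }                                       // the flag and make this call throw IllegalStateException
}
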
@@ -131,12 +138,13 @@ class LLamaAndroid private constructor() : InferenceEngine {
         predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
-        check(_state.value is State.AwaitingUserPrompt) {
+        check(_state.value is State.ModelReady) {
             "User prompt discarded due to: ${_state.value}"
         }

         try {
             Log.i(TAG, "Sending user prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {

@@ -153,10 +161,10 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 } ?: break
             }
             Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             Log.i(TAG, "Generation cancelled by user.")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             Log.e(TAG, "Error during generation!", e)

@@ -170,13 +178,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
-            check(_state.value is State.AwaitingUserPrompt) {
+            check(_state.value is State.ModelReady) {
                 "Benchmark request discarded due to: $state"
             }
             Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _readyForSystemPrompt = false // Just to be safe
             _state.value = State.Benchmarking
             benchModel(pp, tg, pl, nr).also {
-                _state.value = State.AwaitingUserPrompt
+                _state.value = State.ModelReady
             }
         }

@@ -186,8 +195,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
-                is State.AwaitingUserPrompt, is State.Error -> {
+                is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
+                    _readyForSystemPrompt = false
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")

@@ -200,8 +210,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * Cancel all ongoing coroutines and free GGML backends
      */
-    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
     override fun destroy() {
+        _readyForSystemPrompt = false
        llamaScope.cancel()
         when(_state.value) {
             is State.Uninitialized -> {}