LLama: refactor loadModel by splitting the system prompt setting into a separate method
parent 9f77155535
commit 65d4a57a8b
@@ -18,47 +18,4 @@ data class ModelInfo(
 ) {
     val formattedSize: String
         get() = formatSize(sizeInBytes)
-
-    companion object {
-        /**
-         * Creates a list of sample models for development and testing.
-         */
-        fun getSampleModels(): List<ModelInfo> {
-            return listOf(
-                ModelInfo(
-                    id = "mistral-7b",
-                    name = "Mistral 7B",
-                    path = "/storage/models/mistral-7b-q4_0.gguf",
-                    sizeInBytes = 4_000_000_000,
-                    parameters = "7B",
-                    quantization = "Q4_K_M",
-                    type = "Mistral",
-                    contextLength = 8192,
-                    lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago
-                ),
-                ModelInfo(
-                    id = "llama2-13b",
-                    name = "Llama 2 13B",
-                    path = "/storage/models/llama2-13b-q5_k_m.gguf",
-                    sizeInBytes = 8_500_000_000,
-                    parameters = "13B",
-                    quantization = "Q5_K_M",
-                    type = "Llama",
-                    contextLength = 4096,
-                    lastUsed = System.currentTimeMillis() - 259200000 // 3 days ago
-                ),
-                ModelInfo(
-                    id = "phi-2",
-                    name = "Phi-2",
-                    path = "/storage/models/phi-2.gguf",
-                    sizeInBytes = 2_800_000_000,
-                    parameters = "2.7B",
-                    quantization = "Q4_0",
-                    type = "Phi",
-                    contextLength = 2048,
-                    lastUsed = null
-                )
-            )
-        }
-    }
 }
@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): Boolean
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?

     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }

 interface BenchmarkService : InferenceService {
@@ -80,6 +80,17 @@ interface ConversationService : InferenceService {
     fun createTokenMetrics(): TokenMetrics
 }

+/**
+ * Metrics for model loading and system prompt processing
+ */
+data class ModelLoadingMetrics(
+    val modelLoadingTimeMs: Long,
+    val systemPromptProcessingTimeMs: Long? = null
+) {
+    val totalTimeMs: Long
+        get() = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
+}
+
 /**
  * Represents an update during text generation
  */
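For reference, the new metrics type simply adds any prompt-processing time on top of the load time; a quick illustration with made-up numbers:

// Illustrative values only, not taken from a real run.
val benchLoad = ModelLoadingMetrics(modelLoadingTimeMs = 1200)
val chatLoad = ModelLoadingMetrics(modelLoadingTimeMs = 1200, systemPromptProcessingTimeMs = 300)

println(benchLoad.totalTimeMs) // 1200, no system prompt was processed
println(chatLoad.totalTimeMs)  // 1500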
@@ -115,9 +126,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()

-    override fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }

     override suspend fun unloadModel() = inferenceEngine.unloadModel()

@@ -129,29 +138,45 @@ internal class InferenceServiceImpl @Inject internal constructor(

     /* ModelLoadingService implementation */

-    override suspend fun loadModelForBenchmark(): Boolean {
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
         return _currentModel.value?.let { model ->
             try {
+                val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
-                true
+                val modelLoadEndTs = System.currentTimeMillis()
+                ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }

-    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
-                inferenceEngine.loadModel(model.path, systemPrompt)
-                true
+                val modelLoadStartTs = System.currentTimeMillis()
+                inferenceEngine.loadModel(model.path)
+                val modelLoadEndTs = System.currentTimeMillis()
+
+                if (systemPrompt.isNullOrBlank()) {
+                    ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
+                } else {
+                    val prompt: String = systemPrompt
+                    val systemPromptStartTs = System.currentTimeMillis()
+                    inferenceEngine.setSystemPrompt(prompt)
+                    val systemPromptEndTs = System.currentTimeMillis()
+                    ModelLoadingMetrics(
+                        modelLoadingTimeMs = modelLoadEndTs - modelLoadStartTs,
+                        systemPromptProcessingTimeMs = systemPromptEndTs - systemPromptStartTs
+                    )
+                }
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }

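A minimal sketch of how a caller might consume the new nullable return value; the function name, the injected `service`, and the log tag are illustrative and not part of this commit (imports are omitted since package names are not shown in this diff):

suspend fun loadForChat(service: ModelLoadingService, systemPrompt: String?) {
    // Null means loading failed; otherwise the metrics describe the two phases.
    val metrics = service.loadModelForConversation(systemPrompt)
    if (metrics == null) {
        Log.e("Example", "Model failed to load")
        return
    }
    Log.i(
        "Example",
        "Model loaded in ${metrics.modelLoadingTimeMs} ms, " +
            "system prompt in ${metrics.systemPromptProcessingTimeMs ?: 0} ms, " +
            "total ${metrics.totalTimeMs} ms"
    )
}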
@@ -42,9 +42,9 @@ class StubInferenceEngine : InferenceEngine {
     }

     /**
-     * Loads a model from the given path with an optional system prompt.
+     * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
+    override suspend fun loadModel(pathToModel: String) {
         Log.i(TAG, "loadModel! state: ${_state.value}")

         try {
@@ -53,16 +53,28 @@ class StubInferenceEngine : InferenceEngine {
             // Simulate model loading
             delay(STUB_MODEL_LOADING_TIME)

-            _state.value = State.ModelLoaded
+            _state.value = State.ModelReady

-            if (systemPrompt != null) {
+
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+        }
+    }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) {
+        try {
             _state.value = State.ProcessingSystemPrompt

             // Simulate processing system prompt
             delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
-            }

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             throw e
@@ -94,10 +106,10 @@ class StubInferenceEngine : InferenceEngine {
                 delay(STUB_TOKEN_GENERATION_TIME)
             }

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // Handle cancellation gracefully
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during generation")
@@ -144,12 +156,12 @@ class StubInferenceEngine : InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady

             return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
@@ -101,7 +101,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Uninitialized &&
             engineState !is State.LibraryLoaded &&
-            engineState !is State.AwaitingUserPrompt
+            engineState !is State.ModelReady

     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -431,7 +431,7 @@ fun ModelLoadingScreen(
                 text = when (engineState) {
                     is State.LoadingModel -> "Loading model..."
                     is State.ProcessingSystemPrompt -> "Processing system prompt..."
-                    is State.ModelLoaded -> "Preparing conversation..."
+                    is State.ModelReady -> "Preparing conversation..."
                     else -> "Processing..."
                 },
                 style = MaterialTheme.typography.titleMedium
@@ -13,9 +13,14 @@ interface InferenceEngine {
     val state: StateFlow<State>

     /**
-     * Load a model from the given path with an optional system prompt.
+     * Load a model from the given path.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+    suspend fun loadModel(pathToModel: String)

+    /**
+     * Sends a system prompt to the loaded model
+     */
+    suspend fun setSystemPrompt(systemPrompt: String)
+
     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
@@ -45,11 +50,9 @@ interface InferenceEngine {
         object LibraryLoaded : State()

         object LoadingModel : State()
-        object ModelLoaded : State()
+        object ModelReady : State()

         object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
-
         object ProcessingUserPrompt : State()
         object Generating : State()

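With the split interface, callers now sequence the two steps explicitly against the simplified state machine; a minimal sketch of the intended order (the helper function itself is illustrative, not part of this commit):

// loadModel() now only loads; setSystemPrompt() is a separate, optional second step.
suspend fun prepareEngine(engine: InferenceEngine, path: String, systemPrompt: String?) {
    engine.loadModel(path)                   // LoadingModel -> ModelReady
    if (!systemPrompt.isNullOrBlank()) {
        engine.setSystemPrompt(systemPrompt) // ProcessingSystemPrompt -> ModelReady
    }
}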
@@ -18,10 +18,6 @@ import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 import java.io.File

-@Target(AnnotationTarget.FUNCTION)
-@Retention(AnnotationRetention.SOURCE)
-annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
-
 /**
  * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
  *
@@ -63,6 +59,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

+    private var _readyForSystemPrompt = false
+
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
      */
@@ -85,9 +83,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     }

     /**
-     * Load the LLM, then process the plain text system prompt if provided
+     * Load the LLM
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
+    override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {
@@ -96,6 +94,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
             }

             Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
             _state.value = State.LoadingModel
             load(pathToModel).let { result ->
                 if (result != 0) throw IllegalStateException("Failed to Load model: $result")
@@ -104,10 +103,21 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
             }
             Log.i(TAG, "Model loaded!")
-            _state.value = State.ModelLoaded
+            _readyForSystemPrompt = true
+            _state.value = State.ModelReady
+        }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }

-            systemPrompt?.let { prompt ->
             Log.i(TAG, "Sending system prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingSystemPrompt
             processSystemPrompt(prompt).let { result ->
                 if (result != 0) {
@@ -117,10 +127,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 }
             }
             Log.i(TAG, "System prompt processed! Awaiting user prompt...")
-            } ?: run {
-                Log.w(TAG, "No system prompt to process.")
-            }
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         }

     /**
@@ -131,12 +138,13 @@ class LLamaAndroid private constructor() : InferenceEngine {
         predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
-        check(_state.value is State.AwaitingUserPrompt) {
+        check(_state.value is State.ModelReady) {
             "User prompt discarded due to: ${_state.value}"
         }

         try {
             Log.i(TAG, "Sending user prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingUserPrompt
             processUserPrompt(message, predictLength).let { result ->
                 if (result != 0) {
@@ -153,10 +161,10 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 } ?: break
             }
             Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             Log.i(TAG, "Generation cancelled by user.")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             Log.e(TAG, "Error during generation!", e)
@@ -170,13 +178,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
-            check(_state.value is State.AwaitingUserPrompt) {
+            check(_state.value is State.ModelReady) {
                 "Benchmark request discarded due to: $state"
             }
             Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _readyForSystemPrompt = false // Just to be safe
             _state.value = State.Benchmarking
             benchModel(pp, tg, pl, nr).also {
-                _state.value = State.AwaitingUserPrompt
+                _state.value = State.ModelReady
             }
         }

@@ -186,8 +195,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
-                is State.AwaitingUserPrompt, is State.Error -> {
+                is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
+                    _readyForSystemPrompt = false
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
@@ -200,8 +210,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * Cancel all ongoing coroutines and free GGML backends
      */
-    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
     override fun destroy() {
+        _readyForSystemPrompt = false
         llamaScope.cancel()
         when(_state.value) {
             is State.Uninitialized -> {}
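In LLamaAndroid the new _readyForSystemPrompt flag is cleared by every other operation, so a system prompt is only accepted immediately after a successful loadModel; a hedged sketch of the valid ordering (how the LLamaAndroid instance is obtained is outside this diff, and the helper below is illustrative):

// Valid: set the system prompt right after loadModel, while _readyForSystemPrompt is still true.
suspend fun startConversation(llama: LLamaAndroid, path: String, systemPrompt: String) {
    llama.loadModel(path)               // ends in State.ModelReady with the flag set
    llama.setSystemPrompt(systemPrompt) // any prompt, bench, or unload call first would clear the flag
}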