LLama: refactor loadModel by splitting system prompt setup into a separate method

Han Yin 2025-04-16 22:29:40 -07:00
parent 9f77155535
commit 65d4a57a8b
6 changed files with 111 additions and 104 deletions
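At a glance, the refactor turns one combined call into two explicit steps on InferenceEngine. A minimal sketch of the before/after caller pattern (illustrative only; the helper name and the `engine` value are not part of this diff):

// Before this commit: engine.loadModel(pathToModel, systemPrompt)
// After: load first, then optionally set the system prompt while the engine is ModelReady.
suspend fun loadWithOptionalPrompt(engine: InferenceEngine, pathToModel: String, systemPrompt: String?) {
    engine.loadModel(pathToModel)            // drives state to State.ModelReady
    if (!systemPrompt.isNullOrBlank()) {
        engine.setSystemPrompt(systemPrompt) // ProcessingSystemPrompt, then back to ModelReady
    }
}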

View File

@ -18,47 +18,4 @@ data class ModelInfo(
) {
val formattedSize: String
get() = formatSize(sizeInBytes)
companion object {
/**
* Creates a list of sample models for development and testing.
*/
fun getSampleModels(): List<ModelInfo> {
return listOf(
ModelInfo(
id = "mistral-7b",
name = "Mistral 7B",
path = "/storage/models/mistral-7b-q4_0.gguf",
sizeInBytes = 4_000_000_000,
parameters = "7B",
quantization = "Q4_K_M",
type = "Mistral",
contextLength = 8192,
lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago
),
ModelInfo(
id = "llama2-13b",
name = "Llama 2 13B",
path = "/storage/models/llama2-13b-q5_k_m.gguf",
sizeInBytes = 8_500_000_000,
parameters = "13B",
quantization = "Q5_K_M",
type = "Llama",
contextLength = 4096,
lastUsed = System.currentTimeMillis() - 259200000 // 3 days ago
),
ModelInfo(
id = "phi-2",
name = "Phi-2",
path = "/storage/models/phi-2.gguf",
sizeInBytes = 2_800_000_000,
parameters = "2.7B",
quantization = "Q4_0",
type = "Phi",
contextLength = 2048,
lastUsed = null
)
)
}
}
}

View File

@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
/**
* Load a model for benchmark
*/
suspend fun loadModelForBenchmark(): Boolean
suspend fun loadModelForBenchmark(): ModelLoadingMetrics?
/**
* Load a model for conversation
*/
suspend fun loadModelForConversation(systemPrompt: String?): Boolean
suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
}
interface BenchmarkService : InferenceService {
@ -80,6 +80,17 @@ interface ConversationService : InferenceService {
fun createTokenMetrics(): TokenMetrics
}
/**
* Metrics for model loading and system prompt processing
*/
data class ModelLoadingMetrics(
val modelLoadingTimeMs: Long,
val systemPromptProcessingTimeMs: Long? = null
) {
val totalTimeMs: Long
get() = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
}
/**
* Represents an update during text generation
*/
@ -115,9 +126,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
private val _currentModel = MutableStateFlow<ModelInfo?>(null)
override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()
override fun setCurrentModel(model: ModelInfo) {
_currentModel.value = model
}
override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }
override suspend fun unloadModel() = inferenceEngine.unloadModel()
@ -129,29 +138,45 @@ internal class InferenceServiceImpl @Inject internal constructor(
/* ModelLoadingService implementation */
override suspend fun loadModelForBenchmark(): Boolean {
override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
return _currentModel.value?.let { model ->
try {
val modelLoadStartTs = System.currentTimeMillis()
inferenceEngine.loadModel(model.path)
true
val modelLoadEndTs = System.currentTimeMillis()
ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
} catch (e: Exception) {
Log.e("InferenceManager", "Error loading model", e)
false
null
}
}
} == true
}
override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
_systemPrompt.value = systemPrompt
return _currentModel.value?.let { model ->
try {
inferenceEngine.loadModel(model.path, systemPrompt)
true
val modelLoadStartTs = System.currentTimeMillis()
inferenceEngine.loadModel(model.path)
val modelLoadEndTs = System.currentTimeMillis()
if (systemPrompt.isNullOrBlank()) {
ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
} else {
val prompt: String = systemPrompt
val systemPromptStartTs = System.currentTimeMillis()
inferenceEngine.setSystemPrompt(prompt)
val systemPromptEndTs = System.currentTimeMillis()
ModelLoadingMetrics(
modelLoadingTimeMs = modelLoadEndTs - modelLoadStartTs,
systemPromptProcessingTimeMs = systemPromptEndTs - systemPromptStartTs
)
}
} catch (e: Exception) {
Log.e("InferenceManager", "Error loading model", e)
false
null
}
}
} == true
}
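With the Boolean return replaced by ModelLoadingMetrics?, callers can tell failure (null) from success and surface the split timings. A hypothetical consumer, not part of this diff (the `service` and TAG names are placeholders):

val metrics = service.loadModelForConversation(systemPrompt)
if (metrics == null) {
    Log.e(TAG, "Model loading failed")
} else {
    // totalTimeMs = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
    Log.i(TAG, "Model ready in ${metrics.totalTimeMs} ms " +
            "(load: ${metrics.modelLoadingTimeMs} ms, " +
            "system prompt: ${metrics.systemPromptProcessingTimeMs ?: 0} ms)")
}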

View File

@ -42,9 +42,9 @@ class StubInferenceEngine : InferenceEngine {
}
/**
* Loads a model from the given path with an optional system prompt.
* Loads a model from the given path.
*/
override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
override suspend fun loadModel(pathToModel: String) {
Log.i(TAG, "loadModel! state: ${_state.value}")
try {
@ -53,16 +53,28 @@ class StubInferenceEngine : InferenceEngine {
// Simulate model loading
delay(STUB_MODEL_LOADING_TIME)
_state.value = State.ModelLoaded
_state.value = State.ModelReady
if (systemPrompt != null) {
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during model loading")
}
}
/**
* Process the plain text system prompt
*/
override suspend fun setSystemPrompt(prompt: String) {
try {
_state.value = State.ProcessingSystemPrompt
// Simulate processing system prompt
delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
}
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
throw e
@ -94,10 +106,10 @@ class StubInferenceEngine : InferenceEngine {
delay(STUB_TOKEN_GENERATION_TIME)
}
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
} catch (e: CancellationException) {
// Handle cancellation gracefully
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during generation")
@ -144,12 +156,12 @@ class StubInferenceEngine : InferenceEngine {
result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
return result.toString()
} catch (e: CancellationException) {
// If coroutine is cancelled, propagate cancellation
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
_state.value = State.Error(e.message ?: "Unknown error during benchmarking")
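Because the stub mirrors the same two-step flow, the refactor is straightforward to exercise in a unit test. A rough sketch only (hypothetical; it assumes JUnit 4 plus kotlinx-coroutines-test, that android.util.Log works in the test environment, e.g. under Robolectric, and that the stub accepts loadModel from its initial state):

import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Test

class StubInferenceEngineTest {
    @Test
    fun loadThenSystemPromptEndsInModelReady() = runTest {
        val engine = StubInferenceEngine()
        engine.loadModel("/fake/path/model.gguf")              // simulated loading delay
        assertEquals(State.ModelReady, engine.state.value)
        engine.setSystemPrompt("You are a helpful assistant.") // simulated prompt processing
        assertEquals(State.ModelReady, engine.state.value)
    }
}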

View File

@ -101,7 +101,7 @@ fun ModelLoadingScreen(
// Check if we're in a loading state
val isLoading = engineState !is State.Uninitialized &&
engineState !is State.LibraryLoaded &&
engineState !is State.AwaitingUserPrompt
engineState !is State.ModelReady
// Mode selection callbacks
val handleBenchmarkSelected = {
@ -431,7 +431,7 @@ fun ModelLoadingScreen(
text = when (engineState) {
is State.LoadingModel -> "Loading model..."
is State.ProcessingSystemPrompt -> "Processing system prompt..."
is State.ModelLoaded -> "Preparing conversation..."
is State.ModelReady -> "Preparing conversation..."
else -> "Processing..."
},
style = MaterialTheme.typography.titleMedium
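Both UI call sites above now key off the renamed State.ModelReady. If more screens need the same idle-vs-busy check, it could be folded into a small extension (a sketch, not part of this commit; assumes the same State import used by the screen above):

// Only Uninitialized, LibraryLoaded and ModelReady count as idle; every other state shows progress UI.
fun State.isBusy(): Boolean =
    this !is State.Uninitialized &&
            this !is State.LibraryLoaded &&
            this !is State.ModelReady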

View File

@ -13,9 +13,14 @@ interface InferenceEngine {
val state: StateFlow<State>
/**
* Load a model from the given path with an optional system prompt.
* Load a model from the given path.
*/
suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
suspend fun loadModel(pathToModel: String)
/**
* Sends a system prompt to the loaded model
*/
suspend fun setSystemPrompt(systemPrompt: String)
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
@ -45,11 +50,9 @@ interface InferenceEngine {
object LibraryLoaded : State()
object LoadingModel : State()
object ModelLoaded : State()
object ModelReady : State()
object ProcessingSystemPrompt : State()
object AwaitingUserPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
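Taken together with the removed ModelLoaded and AwaitingUserPrompt objects, the lifecycle settles on a single resting state. A summary of the transitions as they appear across this diff (the elided span between ProcessingUserPrompt and ModelReady presumably covers the Generating state while tokens stream):

// LibraryLoaded        -- loadModel -->        LoadingModel           --> ModelReady
// ModelReady           -- setSystemPrompt -->  ProcessingSystemPrompt --> ModelReady
// ModelReady           -- user prompt -->      ProcessingUserPrompt   --> ... --> ModelReady
// ModelReady           -- bench -->            Benchmarking           --> ModelReady
// ModelReady or Error  -- unloadModel -->      LibraryLoaded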

View File

@ -18,10 +18,6 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
@Target(AnnotationTarget.FUNCTION)
@Retention(AnnotationRetention.SOURCE)
annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
/**
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
*
@ -63,6 +59,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
private val _state = MutableStateFlow<State>(State.Uninitialized)
override val state: StateFlow<State> = _state
private var _readyForSystemPrompt = false
/**
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
*/
@ -85,9 +83,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
}
/**
* Load the LLM, then process the plain text system prompt if provided
* Load the LLM
*/
override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
File(pathToModel).let {
@ -96,6 +94,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
}
Log.i(TAG, "Loading model... \n$pathToModel")
_readyForSystemPrompt = false
_state.value = State.LoadingModel
load(pathToModel).let { result ->
if (result != 0) throw IllegalStateException("Failed to Load model: $result")
@ -104,10 +103,21 @@ class LLamaAndroid private constructor() : InferenceEngine {
if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
}
Log.i(TAG, "Model loaded!")
_state.value = State.ModelLoaded
_readyForSystemPrompt = true
_state.value = State.ModelReady
}
/**
* Process the plain text system prompt
*/
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }
systemPrompt?.let { prompt ->
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
_state.value = State.ProcessingSystemPrompt
processSystemPrompt(prompt).let { result ->
if (result != 0) {
@ -117,10 +127,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
}
}
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
} ?: run {
Log.w(TAG, "No system prompt to process.")
}
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
}
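The new _readyForSystemPrompt flag pins the prompt step to the window right after a successful load. A sketch of the enforced ordering (the function and parameter names here are placeholders, not part of the diff):

suspend fun primeModel(llama: LLamaAndroid, pathToModel: String, systemPrompt: String) {
    llama.loadModel(pathToModel)         // sets _readyForSystemPrompt = true, state = ModelReady
    llama.setSystemPrompt(systemPrompt)  // allowed once: the flag is cleared, state returns to ModelReady
    // A second setSystemPrompt(), or one issued after a user prompt, bench or unload,
    // fails the check(_readyForSystemPrompt) and throws IllegalStateException.
}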
/**
@ -131,12 +138,13 @@ class LLamaAndroid private constructor() : InferenceEngine {
predictLength: Int,
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is State.AwaitingUserPrompt) {
check(_state.value is State.ModelReady) {
"User prompt discarded due to: ${_state.value}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
@ -153,10 +161,10 @@ class LLamaAndroid private constructor() : InferenceEngine {
} ?: break
}
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
} catch (e: CancellationException) {
Log.i(TAG, "Generation cancelled by user.")
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
throw e
} catch (e: Exception) {
Log.e(TAG, "Error during generation!", e)
@ -170,13 +178,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is State.AwaitingUserPrompt) {
check(_state.value is State.ModelReady) {
"Benchmark request discarded due to: $state"
}
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
_readyForSystemPrompt = false // Just to be safe
_state.value = State.Benchmarking
benchModel(pp, tg, pl, nr).also {
_state.value = State.AwaitingUserPrompt
_state.value = State.ModelReady
}
}
@ -186,8 +195,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
override suspend fun unloadModel() =
withContext(llamaDispatcher) {
when(_state.value) {
is State.AwaitingUserPrompt, is State.Error -> {
is State.ModelReady, is State.Error -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
unload()
_state.value = State.LibraryLoaded
Log.i(TAG, "Model unloaded!")
@ -200,8 +210,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
/**
* Cancel all ongoing coroutines and free GGML backends
*/
@RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
override fun destroy() {
_readyForSystemPrompt = false
llamaScope.cancel()
when(_state.value) {
is State.Uninitialized -> {}