LLama: refactor loadModel by splitting the system prompt setting into a separate method

Han Yin 2025-04-16 22:29:40 -07:00
parent 9f77155535
commit 65d4a57a8b
6 changed files with 111 additions and 104 deletions
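In short, the single loadModel(path, systemPrompt) call is split into loadModel(path) followed by an explicit setSystemPrompt(prompt). A minimal sketch of the resulting call shape, using illustrative names rather than code taken verbatim from the diff:

// Sketch only: `engine`, `modelPath`, and `systemPrompt` are placeholders for
// the values used by InferenceServiceImpl in the diff below.
suspend fun loadForConversation(engine: InferenceEngine, modelPath: String, systemPrompt: String?) {
    // Before this commit: engine.loadModel(modelPath, systemPrompt)
    engine.loadModel(modelPath)
    if (!systemPrompt.isNullOrBlank()) {
        engine.setSystemPrompt(systemPrompt)
    }
}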


@@ -18,47 +18,4 @@ data class ModelInfo(
 ) {
     val formattedSize: String
         get() = formatSize(sizeInBytes)
-
-    companion object {
-        /**
-         * Creates a list of sample models for development and testing.
-         */
-        fun getSampleModels(): List<ModelInfo> {
-            return listOf(
-                ModelInfo(
-                    id = "mistral-7b",
-                    name = "Mistral 7B",
-                    path = "/storage/models/mistral-7b-q4_0.gguf",
-                    sizeInBytes = 4_000_000_000,
-                    parameters = "7B",
-                    quantization = "Q4_K_M",
-                    type = "Mistral",
-                    contextLength = 8192,
-                    lastUsed = System.currentTimeMillis() - 86400000 // 1 day ago
-                ),
-                ModelInfo(
-                    id = "llama2-13b",
-                    name = "Llama 2 13B",
-                    path = "/storage/models/llama2-13b-q5_k_m.gguf",
-                    sizeInBytes = 8_500_000_000,
-                    parameters = "13B",
-                    quantization = "Q5_K_M",
-                    type = "Llama",
-                    contextLength = 4096,
-                    lastUsed = System.currentTimeMillis() - 259200000 // 3 days ago
-                ),
-                ModelInfo(
-                    id = "phi-2",
-                    name = "Phi-2",
-                    path = "/storage/models/phi-2.gguf",
-                    sizeInBytes = 2_800_000_000,
-                    parameters = "2.7B",
-                    quantization = "Q4_0",
-                    type = "Phi",
-                    contextLength = 2048,
-                    lastUsed = null
-                )
-            )
-        }
-    }
 }


@@ -38,12 +38,12 @@ interface ModelLoadingService : InferenceService {
     /**
      * Load a model for benchmark
      */
-    suspend fun loadModelForBenchmark(): Boolean
+    suspend fun loadModelForBenchmark(): ModelLoadingMetrics?

     /**
      * Load a model for conversation
      */
-    suspend fun loadModelForConversation(systemPrompt: String?): Boolean
+    suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics?
 }

 interface BenchmarkService : InferenceService {
@@ -80,6 +80,17 @@ interface ConversationService : InferenceService {
     fun createTokenMetrics(): TokenMetrics
 }

+/**
+ * Metrics for model loading and system prompt processing
+ */
+data class ModelLoadingMetrics(
+    val modelLoadingTimeMs: Long,
+    val systemPromptProcessingTimeMs: Long? = null
+) {
+    val totalTimeMs: Long
+        get() = modelLoadingTimeMs + (systemPromptProcessingTimeMs ?: 0)
+}
+
 /**
  * Represents an update during text generation
  */
@@ -115,9 +126,7 @@ internal class InferenceServiceImpl @Inject internal constructor(
     private val _currentModel = MutableStateFlow<ModelInfo?>(null)
     override val currentSelectedModel: StateFlow<ModelInfo?> = _currentModel.asStateFlow()

-    override fun setCurrentModel(model: ModelInfo) {
-        _currentModel.value = model
-    }
+    override fun setCurrentModel(model: ModelInfo) { _currentModel.value = model }

     override suspend fun unloadModel() = inferenceEngine.unloadModel()

@@ -129,29 +138,45 @@ internal class InferenceServiceImpl @Inject internal constructor(
     /* ModelLoadingService implementation */

-    override suspend fun loadModelForBenchmark(): Boolean {
+    override suspend fun loadModelForBenchmark(): ModelLoadingMetrics? {
         return _currentModel.value?.let { model ->
             try {
+                val modelLoadStartTs = System.currentTimeMillis()
                 inferenceEngine.loadModel(model.path)
-                true
+                val modelLoadEndTs = System.currentTimeMillis()
+                ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }

-    override suspend fun loadModelForConversation(systemPrompt: String?): Boolean {
+    override suspend fun loadModelForConversation(systemPrompt: String?): ModelLoadingMetrics? {
         _systemPrompt.value = systemPrompt
         return _currentModel.value?.let { model ->
             try {
-                inferenceEngine.loadModel(model.path, systemPrompt)
-                true
+                val modelLoadStartTs = System.currentTimeMillis()
+                inferenceEngine.loadModel(model.path)
+                val modelLoadEndTs = System.currentTimeMillis()
+                if (systemPrompt.isNullOrBlank()) {
+                    ModelLoadingMetrics(modelLoadEndTs - modelLoadStartTs)
+                } else {
+                    val prompt: String = systemPrompt
+                    val systemPromptStartTs = System.currentTimeMillis()
+                    inferenceEngine.setSystemPrompt(prompt)
+                    val systemPromptEndTs = System.currentTimeMillis()
+                    ModelLoadingMetrics(
+                        modelLoadingTimeMs = modelLoadEndTs - modelLoadStartTs,
+                        systemPromptProcessingTimeMs = systemPromptEndTs - systemPromptStartTs
+                    )
+                }
             } catch (e: Exception) {
                 Log.e("InferenceManager", "Error loading model", e)
-                false
+                null
             }
-        } == true
+        }
     }
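For context, a rough sketch of how a caller might consume the new nullable ModelLoadingMetrics result; the ViewModel, its injection, and the log tag are illustrative assumptions, not part of this commit:

import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.launch

// Hypothetical caller; only loadModelForConversation() and ModelLoadingMetrics
// come from the diff above.
class LoadModelViewModel(private val modelLoadingService: ModelLoadingService) : ViewModel() {
    fun loadForConversation(systemPrompt: String?) {
        viewModelScope.launch {
            val metrics = modelLoadingService.loadModelForConversation(systemPrompt)
            if (metrics == null) {
                Log.e("LoadModelViewModel", "Model failed to load")
            } else {
                // totalTimeMs adds the optional system prompt time to the load time.
                Log.i("LoadModelViewModel", "Loaded in ${metrics.modelLoadingTimeMs} ms, total ${metrics.totalTimeMs} ms")
            }
        }
    }
}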


@@ -42,9 +42,9 @@ class StubInferenceEngine : InferenceEngine {
     }

     /**
-     * Loads a model from the given path with an optional system prompt.
+     * Loads a model from the given path.
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) {
+    override suspend fun loadModel(pathToModel: String) {
         Log.i(TAG, "loadModel! state: ${_state.value}")
         try {
@@ -53,16 +53,28 @@ class StubInferenceEngine : InferenceEngine {
             // Simulate model loading
             delay(STUB_MODEL_LOADING_TIME)
-            _state.value = State.ModelLoaded
-            if (systemPrompt != null) {
+            _state.value = State.ModelReady
+        } catch (e: CancellationException) {
+            // If coroutine is cancelled, propagate cancellation
+            throw e
+        } catch (e: Exception) {
+            _state.value = State.Error(e.message ?: "Unknown error during model loading")
+        }
+    }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) {
+        try {
             _state.value = State.ProcessingSystemPrompt
             // Simulate processing system prompt
             delay(STUB_SYSTEM_PROMPT_PROCESSING_TIME)
-            }
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
             throw e
@@ -94,10 +106,10 @@ class StubInferenceEngine : InferenceEngine {
                 delay(STUB_TOKEN_GENERATION_TIME)
             }

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         } catch (e: CancellationException) {
             // Handle cancellation gracefully
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during generation")
@@ -144,12 +156,12 @@ class StubInferenceEngine : InferenceEngine {
             result.append("| $modelDesc | ${model_size}GiB | ${model_n_params}B | ")
             result.append("$backend | tg $tg | $tg_avg ± $tg_std |\n")

-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             return result.toString()
         } catch (e: CancellationException) {
             // If coroutine is cancelled, propagate cancellation
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
             throw e
         } catch (e: Exception) {
             _state.value = State.Error(e.message ?: "Unknown error during benchmarking")
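A possible coroutine test against the stub, sketched under the assumptions that the stub accepts loadModel from its initial state and that State is nested in InferenceEngine; kotlin-test and kotlinx-coroutines-test are assumed dependencies:

import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals

class StubInferenceEngineTest {
    @Test
    fun loadModelThenSystemPromptEndsInModelReady() = runTest {
        val engine = StubInferenceEngine()
        // Assumed: the stub accepts loadModel from its initial state.
        engine.loadModel("/fake/models/tiny.gguf")
        assertEquals(InferenceEngine.State.ModelReady, engine.state.value)

        engine.setSystemPrompt("You are a helpful assistant.")
        assertEquals(InferenceEngine.State.ModelReady, engine.state.value)
    }
}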


@@ -101,7 +101,7 @@ fun ModelLoadingScreen(
     // Check if we're in a loading state
     val isLoading = engineState !is State.Uninitialized &&
             engineState !is State.LibraryLoaded &&
-            engineState !is State.AwaitingUserPrompt
+            engineState !is State.ModelReady

     // Mode selection callbacks
     val handleBenchmarkSelected = {
@@ -431,7 +431,7 @@ fun ModelLoadingScreen(
         text = when (engineState) {
             is State.LoadingModel -> "Loading model..."
             is State.ProcessingSystemPrompt -> "Processing system prompt..."
-            is State.ModelLoaded -> "Preparing conversation..."
+            is State.ModelReady -> "Preparing conversation..."
             else -> "Processing..."
         },
         style = MaterialTheme.typography.titleMedium


@@ -13,9 +13,14 @@ interface InferenceEngine {
     val state: StateFlow<State>

     /**
-     * Load a model from the given path with an optional system prompt.
+     * Load a model from the given path.
      */
-    suspend fun loadModel(pathToModel: String, systemPrompt: String? = null)
+    suspend fun loadModel(pathToModel: String)
+
+    /**
+     * Sends a system prompt to the loaded model
+     */
+    suspend fun setSystemPrompt(systemPrompt: String)

     /**
      * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
@@ -45,11 +50,9 @@ interface InferenceEngine {
         object LibraryLoaded : State()
         object LoadingModel : State()
-        object ModelLoaded : State()
+        object ModelReady : State()
         object ProcessingSystemPrompt : State()
-        object AwaitingUserPrompt : State()
         object ProcessingUserPrompt : State()
         object Generating : State()
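For orientation, a sketch of observing the renamed states from the interface above; the CoroutineScope parameter, the State nesting inside InferenceEngine, and the status strings are assumptions for illustration:

import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.launch

// Collects the engine's StateFlow and reports the phases touched by this commit.
fun observeEngineState(scope: CoroutineScope, engine: InferenceEngine) {
    scope.launch {
        engine.state.collect { state ->
            when (state) {
                is InferenceEngine.State.LoadingModel -> println("Loading model...")
                is InferenceEngine.State.ProcessingSystemPrompt -> println("Processing system prompt...")
                is InferenceEngine.State.ModelReady -> println("Ready for a user prompt")
                else -> println("Engine state: $state")
            }
        }
    }
}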


@@ -18,10 +18,6 @@ import kotlinx.coroutines.launch
 import kotlinx.coroutines.withContext
 import java.io.File

-@Target(AnnotationTarget.FUNCTION)
-@Retention(AnnotationRetention.SOURCE)
-annotation class RequiresCleanup(val message: String = "Remember to call this method for proper cleanup!")
-
 /**
  * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
  *
@@ -63,6 +59,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     private val _state = MutableStateFlow<State>(State.Uninitialized)
     override val state: StateFlow<State> = _state

+    private var _readyForSystemPrompt = false
+
     /**
      * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
      */
@@ -85,9 +83,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     }

     /**
-     * Load the LLM, then process the plain text system prompt if provided
+     * Load the LLM
      */
-    override suspend fun loadModel(pathToModel: String, systemPrompt: String?) =
+    override suspend fun loadModel(pathToModel: String) =
         withContext(llamaDispatcher) {
             check(_state.value is State.LibraryLoaded) { "Cannot load model in ${_state.value}!" }
             File(pathToModel).let {
@@ -96,6 +94,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
             }

             Log.i(TAG, "Loading model... \n$pathToModel")
+            _readyForSystemPrompt = false
             _state.value = State.LoadingModel
             load(pathToModel).let { result ->
                 if (result != 0) throw IllegalStateException("Failed to Load model: $result")
@@ -104,10 +103,21 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 if (result != 0) throw IllegalStateException("Failed to prepare resources: $result")
             }
             Log.i(TAG, "Model loaded!")
-            _state.value = State.ModelLoaded
+            _readyForSystemPrompt = true
+            _state.value = State.ModelReady
+        }
+
+    /**
+     * Process the plain text system prompt
+     */
+    override suspend fun setSystemPrompt(prompt: String) =
+        withContext(llamaDispatcher) {
+            require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+            check(_state.value is State.ModelReady) { "Cannot process system prompt in ${_state.value}!" }

-            systemPrompt?.let { prompt ->
             Log.i(TAG, "Sending system prompt...")
+            _readyForSystemPrompt = false
             _state.value = State.ProcessingSystemPrompt
             processSystemPrompt(prompt).let { result ->
                 if (result != 0) {
@@ -117,10 +127,7 @@ class LLamaAndroid private constructor() : InferenceEngine {
                 }
             }
             Log.i(TAG, "System prompt processed! Awaiting user prompt...")
-            } ?: run {
-                Log.w(TAG, "No system prompt to process.")
-            }
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
         }

     /**
@@ -131,12 +138,13 @@ class LLamaAndroid private constructor() : InferenceEngine {
         predictLength: Int,
     ): Flow<String> = flow {
         require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
-        check(_state.value is State.AwaitingUserPrompt) {
+        check(_state.value is State.ModelReady) {
            "User prompt discarded due to: ${_state.value}"
        }

        try {
            Log.i(TAG, "Sending user prompt...")
+            _readyForSystemPrompt = false
            _state.value = State.ProcessingUserPrompt
            processUserPrompt(message, predictLength).let { result ->
                if (result != 0) {
@@ -153,10 +161,10 @@ class LLamaAndroid private constructor() : InferenceEngine {
                } ?: break
            }
            Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
        } catch (e: CancellationException) {
            Log.i(TAG, "Generation cancelled by user.")
-            _state.value = State.AwaitingUserPrompt
+            _state.value = State.ModelReady
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Error during generation!", e)
@@ -170,13 +178,14 @@ class LLamaAndroid private constructor() : InferenceEngine {
      */
     override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
         withContext(llamaDispatcher) {
-            check(_state.value is State.AwaitingUserPrompt) {
+            check(_state.value is State.ModelReady) {
                 "Benchmark request discarded due to: $state"
             }

             Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+            _readyForSystemPrompt = false // Just to be safe
             _state.value = State.Benchmarking
             benchModel(pp, tg, pl, nr).also {
-                _state.value = State.AwaitingUserPrompt
+                _state.value = State.ModelReady
             }
         }
@@ -186,8 +195,9 @@ class LLamaAndroid private constructor() : InferenceEngine {
     override suspend fun unloadModel() =
         withContext(llamaDispatcher) {
             when(_state.value) {
-                is State.AwaitingUserPrompt, is State.Error -> {
+                is State.ModelReady, is State.Error -> {
                     Log.i(TAG, "Unloading model and free resources...")
+                    _readyForSystemPrompt = false
                     unload()
                     _state.value = State.LibraryLoaded
                     Log.i(TAG, "Model unloaded!")
@@ -200,8 +210,8 @@ class LLamaAndroid private constructor() : InferenceEngine {
     /**
      * Cancel all ongoing coroutines and free GGML backends
      */
-    @RequiresCleanup("Call from `ViewModel.onCleared()` to prevent resource leaks!")
     override fun destroy() {
+        _readyForSystemPrompt = false
         llamaScope.cancel()
         when(_state.value) {
             is State.Uninitialized -> {}