This commit is contained in:
Naco Siren 2025-12-17 05:51:06 +02:00 committed by GitHub
commit 93dc70c307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 2588 additions and 1270 deletions

View File

@ -32,7 +32,7 @@
/examples/export-docs/ @ggerganov
/examples/gen-docs/ @ggerganov
/examples/gguf/ @ggerganov
/examples/llama.android/ @ggerganov
/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren
/examples/llama.swiftui/ @ggerganov
/examples/llama.vim @ggerganov
/examples/lookahead/ @ggerganov

View File

@ -190,6 +190,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
- Android: [llama.android](/examples/llama.android)
</details>

View File

@ -1,6 +1,26 @@
# Android
## Build with Android Studio
Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
![Project imported into Android Studio](./android/imported-into-android-studio.png)
This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
A minimal Android app frontend is included to showcase the bindings core functionalities:
1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`.
2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs.
3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`.
For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:
| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) |
|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:|
| Home screen | System prompt | "Haiku" |
## Build on Android using Termux
[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.

View File

@ -1,16 +1,18 @@
plugins {
id("com.android.application")
id("org.jetbrains.kotlin.android")
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.example.llama"
compileSdk = 34
compileSdk = 36
defaultConfig {
applicationId = "com.example.llama"
applicationId = "com.example.llama.aichat"
minSdk = 33
targetSdk = 34
targetSdk = 36
versionCode = 1
versionName = "1.0"
@ -21,8 +23,17 @@ android {
}
buildTypes {
debug {
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android.txt"),
"proguard-rules.pro"
)
}
release {
isMinifyEnabled = false
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
@ -36,30 +47,15 @@ android {
kotlinOptions {
jvmTarget = "1.8"
}
buildFeatures {
compose = true
}
composeOptions {
kotlinCompilerExtensionVersion = "1.5.1"
}
}
dependencies {
implementation(libs.bundles.androidx)
implementation(libs.material)
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
implementation("androidx.activity:activity-compose:1.8.2")
implementation(platform("androidx.compose:compose-bom:2023.08.00"))
implementation("androidx.compose.ui:ui")
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
implementation(project(":llama"))
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
androidTestImplementation("androidx.compose.ui:ui-test-junit4")
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")
implementation(project(":lib"))
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
}

View File

@ -19,3 +19,11 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-assumenosideeffects class android.util.Log {
public static int v(...);
public static int d(...);
}

View File

@ -1,24 +1,21 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:extractNativeLibs="true"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:icon="@mipmap/ic_launcher_round"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.LlamaAndroid"
android:theme="@style/Theme.AiChatSample"
>
<activity
android:name=".MainActivity"
android:exported="true"
android:theme="@style/Theme.LlamaAndroid">
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

View File

@ -1,119 +0,0 @@
package com.example.llama
import android.app.DownloadManager
import android.net.Uri
import android.util.Log
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableDoubleStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.core.database.getLongOrNull
import androidx.core.net.toUri
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import java.io.File
data class Downloadable(val name: String, val source: Uri, val destination: File) {
companion object {
@JvmStatic
private val tag: String? = this::class.qualifiedName
sealed interface State
data object Ready: State
data class Downloading(val id: Long): State
data class Downloaded(val downloadable: Downloadable): State
data class Error(val message: String): State
@JvmStatic
@Composable
fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
var status: State by remember {
mutableStateOf(
if (item.destination.exists()) Downloaded(item)
else Ready
)
}
var progress by remember { mutableDoubleStateOf(0.0) }
val coroutineScope = rememberCoroutineScope()
suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
while (true) {
val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
if (cursor == null) {
Log.e(tag, "dm.query() returned null")
return Error("dm.query() returned null")
}
if (!cursor.moveToFirst() || cursor.count < 1) {
cursor.close()
Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
return Ready
}
val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
val sofar = cursor.getLongOrNull(pix) ?: 0
val total = cursor.getLongOrNull(tix) ?: 1
cursor.close()
if (sofar == total) {
return Downloaded(item)
}
progress = (sofar * 1.0) / total
delay(1000L)
}
}
fun onClick() {
when (val s = status) {
is Downloaded -> {
viewModel.load(item.destination.path)
}
is Downloading -> {
coroutineScope.launch {
status = waitForDownload(s, item)
}
}
else -> {
item.destination.delete()
val request = DownloadManager.Request(item.source).apply {
setTitle("Downloading model")
setDescription("Downloading model: ${item.name}")
setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
setDestinationUri(item.destination.toUri())
}
viewModel.log("Saving ${item.name} to ${item.destination.path}")
Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
val id = dm.enqueue(request)
status = Downloading(id)
onClick()
}
}
}
Button(onClick = { onClick() }, enabled = status !is Downloading) {
when (status) {
is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
is Downloaded -> Text("Load ${item.name}")
is Ready -> Text("Download ${item.name}")
is Error -> Text("Download ${item.name}")
}
}
}
}
}

View File

@ -1,154 +1,257 @@
package com.example.llama
import android.app.ActivityManager
import android.app.DownloadManager
import android.content.ClipData
import android.content.ClipboardManager
import android.net.Uri
import android.os.Bundle
import android.os.StrictMode
import android.os.StrictMode.VmPolicy
import android.text.format.Formatter
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.viewModels
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.material3.Button
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.core.content.getSystemService
import com.example.llama.ui.theme.LlamaAndroidTheme
import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import com.arm.aichat.AiChat
import com.arm.aichat.InferenceEngine
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.util.UUID
class MainActivity(
activityManager: ActivityManager? = null,
downloadManager: DownloadManager? = null,
clipboardManager: ClipboardManager? = null,
): ComponentActivity() {
private val tag: String? = this::class.simpleName
class MainActivity : AppCompatActivity() {
private val activityManager by lazy { activityManager ?: getSystemService<ActivityManager>()!! }
private val downloadManager by lazy { downloadManager ?: getSystemService<DownloadManager>()!! }
private val clipboardManager by lazy { clipboardManager ?: getSystemService<ClipboardManager>()!! }
// Android views
private lateinit var ggufTv: TextView
private lateinit var messagesRv: RecyclerView
private lateinit var userInputEt: EditText
private lateinit var userActionFab: FloatingActionButton
private val viewModel: MainViewModel by viewModels()
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
// Get a MemoryInfo object for the device's current memory status.
private fun availableMemory(): ActivityManager.MemoryInfo {
return ActivityManager.MemoryInfo().also { memoryInfo ->
activityManager.getMemoryInfo(memoryInfo)
}
}
// Conversation states
private var isModelReady = false
private val messages = mutableListOf<Message>()
private val lastAssistantMsg = StringBuilder()
private val messageAdapter = MessageAdapter(messages)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
StrictMode.setVmPolicy(
VmPolicy.Builder(StrictMode.getVmPolicy())
.detectLeakedClosableObjects()
.build()
)
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
messagesRv.layoutManager = LinearLayoutManager(this)
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
val free = Formatter.formatFileSize(this, availableMemory().availMem)
val total = Formatter.formatFileSize(this, availableMemory().totalMem)
viewModel.log("Current memory: $free / $total")
viewModel.log("Downloads directory: ${getExternalFilesDir(null)}")
val extFilesDir = getExternalFilesDir(null)
val models = listOf(
Downloadable(
"Phi-2 7B (Q4_0, 1.6 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
File(extFilesDir, "phi-2-q4_0.gguf"),
),
Downloadable(
"TinyLlama 1.1B (f16, 2.2 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
File(extFilesDir, "tinyllama-1.1-f16.gguf"),
),
Downloadable(
"Phi 2 DPO (Q3_K_M, 1.48 GiB)",
Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
),
)
setContent {
LlamaAndroidTheme {
// A surface container using the 'background' color from the theme
Surface(
modifier = Modifier.fillMaxSize(),
color = MaterialTheme.colorScheme.background
) {
MainCompose(
viewModel,
clipboardManager,
downloadManager,
models,
)
// Arm AI Chat initialization
lifecycleScope.launch(Dispatchers.Default) {
engine = AiChat.getInferenceEngine(applicationContext)
}
// Upon CTA button tapped
userActionFab.setOnClickListener {
if (isModelReady) {
// If model is ready, validate input and send to engine
handleUserInput()
} else {
// Otherwise, prompt user to select a GGUF metadata on the device
getContent.launch(arrayOf("*/*"))
}
}
}
private val getContent = registerForActivityResult(
ActivityResultContracts.OpenDocument()
) { uri ->
Log.i(TAG, "Selected file uri:\n $uri")
uri?.let { handleSelectedModel(it) }
}
/**
* Handles the file Uri from [getContent] result
*/
private fun handleSelectedModel(uri: Uri) {
// Update UI states
userActionFab.isEnabled = false
userInputEt.hint = "Parsing GGUF..."
ggufTv.text = "Parsing metadata from selected file \n$uri"
lifecycleScope.launch(Dispatchers.IO) {
// Parse GGUF metadata
Log.i(TAG, "Parsing GGUF metadata...")
contentResolver.openInputStream(uri)?.use {
GgufMetadataReader.create().readStructuredMetadata(it)
}?.let { metadata ->
// Update UI to show GGUF metadata to user
Log.i(TAG, "GGUF parsed: \n$metadata")
withContext(Dispatchers.Main) {
ggufTv.text = metadata.toString()
}
// Ensure the model file is available
val modelName = metadata.filename() + FILE_EXTENSION_GGUF
contentResolver.openInputStream(uri)?.use { input ->
ensureModelFile(modelName, input)
}?.let { modelFile ->
loadModel(modelName, modelFile)
withContext(Dispatchers.Main) {
isModelReady = true
userInputEt.hint = "Type and send a message!"
userInputEt.isEnabled = true
userActionFab.setImageResource(R.drawable.outline_send_24)
userActionFab.isEnabled = true
}
}
}
}
}
@Composable
fun MainCompose(
viewModel: MainViewModel,
clipboard: ClipboardManager,
dm: DownloadManager,
models: List<Downloadable>
) {
Column {
val scrollState = rememberLazyListState()
Box(modifier = Modifier.weight(1f)) {
LazyColumn(state = scrollState) {
items(viewModel.messages) {
Text(
it,
style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
modifier = Modifier.padding(16.dp)
)
}
}
}
OutlinedTextField(
value = viewModel.message,
onValueChange = { viewModel.updateMessage(it) },
label = { Text("Message") },
)
Row {
Button({ viewModel.send() }) { Text("Send") }
Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
Button({ viewModel.clear() }) { Text("Clear") }
Button({
viewModel.messages.joinToString("\n").let {
clipboard.setPrimaryClip(ClipData.newPlainText("", it))
}
}) { Text("Copy") }
/**
* Prepare the model file within app's private storage
*/
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
withContext(Dispatchers.IO) {
File(ensureModelsDirectory(), modelName).also { file ->
// Copy the file into local storage if not yet done
if (!file.exists()) {
Log.i(TAG, "Start copying file to $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Copying file..."
}
Column {
for (model in models) {
Downloadable.Button(viewModel, dm, model)
FileOutputStream(file).use { input.copyTo(it) }
Log.i(TAG, "Finished copying file to $modelName")
} else {
Log.i(TAG, "File already exists $modelName")
}
}
}
/**
* Load the model file from the app private storage
*/
private suspend fun loadModel(modelName: String, modelFile: File) =
withContext(Dispatchers.IO) {
Log.i(TAG, "Loading model $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Loading model..."
}
engine.loadModel(modelFile.path)
}
/**
* Validate and send the user message into [InferenceEngine]
*/
private fun handleUserInput() {
userInputEt.text.toString().also { userSsg ->
if (userSsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
userActionFab.isEnabled = false
// Update message states
messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
lifecycleScope.launch(Dispatchers.Default) {
engine.sendUserPrompt(userSsg)
.onCompletion {
withContext(Dispatchers.Main) {
userActionFab.isEnabled = true
}
}.collect { token ->
val messageCount = messages.size
check(messageCount > 0 && !messages[messageCount - 1].isUser)
messages.removeAt(messageCount - 1).copy(
content = lastAssistantMsg.append(token).toString()
).let { messages.add(it) }
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
}
}
}
/**
* Run a benchmark with the model file
*/
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Running benchmark..."
}
engine.bench(
pp=BENCH_PROMPT_PROCESSING_TOKENS,
tg=BENCH_TOKEN_GENERATION_TOKENS,
pl=BENCH_SEQUENCE,
nr=BENCH_REPETITION
).let { result ->
messages.add(Message(UUID.randomUUID().toString(), result, false))
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
/**
* Create the `models` directory if not exist.
*/
private fun ensureModelsDirectory() =
File(filesDir, DIRECTORY_MODELS).also {
if (it.exists() && !it.isDirectory) { it.delete() }
if (!it.exists()) { it.mkdir() }
}
companion object {
private val TAG = MainActivity::class.java.simpleName
private const val DIRECTORY_MODELS = "models"
private const val FILE_EXTENSION_GGUF = ".gguf"
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
private const val BENCH_SEQUENCE = 1
private const val BENCH_REPETITION = 3
}
}
fun GgufMetadata.filename() = when {
basic.name != null -> {
basic.name?.let { name ->
basic.sizeLabel?.let { size ->
"$name-$size"
} ?: name
}
}
architecture?.architecture != null -> {
architecture?.architecture?.let { arch ->
basic.uuid?.let { uuid ->
"$arch-$uuid"
} ?: "$arch-${System.currentTimeMillis()}"
}
}
else -> {
"model-${System.currentTimeMillis().toHexString()}"
}
}

View File

@ -1,105 +0,0 @@
package com.example.llama
import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
}
private val tag: String? = this::class.simpleName
var messages by mutableStateOf(listOf("Initializing..."))
private set
var message by mutableStateOf("")
private set
override fun onCleared() {
super.onCleared()
viewModelScope.launch {
try {
llamaAndroid.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
}
}
fun send() {
val text = message
message = ""
// Add to messages console.
messages += text
messages += ""
viewModelScope.launch {
llamaAndroid.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
}
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
}
}
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
viewModelScope.launch {
try {
val start = System.nanoTime()
val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
val warmup = (end - start).toDouble() / NanosPerSecond
messages += "Warm up time: $warmup seconds, please wait..."
if (warmup > 5.0) {
messages += "Warm up took too long, aborting benchmark"
return@launch
}
messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
}
}
}
fun load(pathToModel: String) {
viewModelScope.launch {
try {
llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
messages += exc.message!!
}
}
}
fun updateMessage(newMessage: String) {
message = newMessage
}
fun clear() {
messages = listOf()
}
fun log(message: String) {
messages += message
}
}

View File

@ -0,0 +1,51 @@
package com.example.llama
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import android.widget.TextView
import androidx.recyclerview.widget.RecyclerView
data class Message(
val id: String,
val content: String,
val isUser: Boolean
)
class MessageAdapter(
private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
companion object {
private const val VIEW_TYPE_USER = 1
private const val VIEW_TYPE_ASSISTANT = 2
}
override fun getItemViewType(position: Int): Int {
return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
val layoutInflater = LayoutInflater.from(parent.context)
return if (viewType == VIEW_TYPE_USER) {
val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
UserMessageViewHolder(view)
} else {
val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
AssistantMessageViewHolder(view)
}
}
override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
val message = messages[position]
if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
textView.text = message.content
}
}
override fun getItemCount(): Int = messages.size
class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
}

View File

@ -1,11 +0,0 @@
package com.example.llama.ui.theme
import androidx.compose.ui.graphics.Color
val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)
val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)

View File

@ -1,70 +0,0 @@
package com.example.llama.ui.theme
import android.app.Activity
import android.os.Build
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat
private val DarkColorScheme = darkColorScheme(
primary = Purple80,
secondary = PurpleGrey80,
tertiary = Pink80
)
private val LightColorScheme = lightColorScheme(
primary = Purple40,
secondary = PurpleGrey40,
tertiary = Pink40
/* Other default colors to override
background = Color(0xFFFFFBFE),
surface = Color(0xFFFFFBFE),
onPrimary = Color.White,
onSecondary = Color.White,
onTertiary = Color.White,
onBackground = Color(0xFF1C1B1F),
onSurface = Color(0xFF1C1B1F),
*/
)
@Composable
fun LlamaAndroidTheme(
darkTheme: Boolean = isSystemInDarkTheme(),
// Dynamic color is available on Android 12+
dynamicColor: Boolean = true,
content: @Composable () -> Unit
) {
val colorScheme = when {
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
val context = LocalContext.current
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
}
darkTheme -> DarkColorScheme
else -> LightColorScheme
}
val view = LocalView.current
if (!view.isInEditMode) {
SideEffect {
val window = (view.context as Activity).window
window.statusBarColor = colorScheme.primary.toArgb()
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
}
}
MaterialTheme(
colorScheme = colorScheme,
typography = Typography,
content = content
)
}

View File

@ -1,34 +0,0 @@
package com.example.llama.ui.theme
import androidx.compose.material3.Typography
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp
// Set of Material typography styles to start with
val Typography = Typography(
bodyLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 16.sp,
lineHeight = 24.sp,
letterSpacing = 0.5.sp
)
/* Other default text styles to override
titleLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 22.sp,
lineHeight = 28.sp,
letterSpacing = 0.sp
),
labelSmall = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Medium,
fontSize = 11.sp,
lineHeight = 16.sp,
letterSpacing = 0.5.sp
)
*/
)

View File

@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#E5E5EA" />
<corners android:radius="16dp" />
</shape>

View File

@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#4285F4" />
<corners android:radius="16dp" />
</shape>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
</vector>

View File

@ -0,0 +1,11 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal"
android:autoMirrored="true">
<path
android:fillColor="@android:color/white"
android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
</vector>

View File

@ -0,0 +1,76 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/main"
android:layout_height="match_parent"
android:layout_width="match_parent">
<LinearLayout
android:fitsSystemWindows="true"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<FrameLayout
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1">
<ScrollView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:fadeScrollbars="false">
<TextView
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_margin="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2"
android:maxLines="100" />
</ScrollView>
</FrameLayout>
<androidx.recyclerview.widget.RecyclerView
android:id="@+id/messages"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="4"
android:padding="16dp"
android:fadeScrollbars="false"
app:reverseLayout="true"
tools:listitem="@layout/item_message_assistant"/>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<EditText
android:id="@+id/user_input"
android:enabled="false"
android:layout_width="0dp"
android:layout_weight="1"
android:layout_height="match_parent"
android:padding="8dp"
style="@style/TextAppearance.MaterialComponents.Body2"
android:hint="Please first pick a GGUF model file to import." />
<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/fab"
android:enabled="true"
style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_margin="8dp"
android:src="@drawable/outline_folder_open_24" />
</LinearLayout>
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="start">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_assistant_message"
android:padding="12dp"
android:textColor="@android:color/black" />
</LinearLayout>

View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="end">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_user_message"
android:padding="12dp"
android:textColor="@android:color/white" />
</LinearLayout>

View File

@ -1,3 +1,3 @@
<resources>
<string name="app_name">LlamaAndroid</string>
<string name="app_name">AI Chat basic sample</string>
</resources>

View File

@ -1,5 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<style name="Theme.LlamaAndroid" parent="android:Theme.Material.Light.NoActionBar" />
<style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
<!-- Customize your light theme here. -->
<!-- <item name="colorPrimary">@color/my_light_primary</item> -->
</style>
<style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
</resources>

View File

@ -1,6 +1,6 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false
id("com.android.library") version "8.2.0" apply false
alias(libs.plugins.android.application) apply false
alias(libs.plugins.android.library) apply false
alias(libs.plugins.jetbrains.kotlin.android) apply false
}

View File

@ -21,3 +21,4 @@ kotlin.code.style=official
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
android.native.buildOutput=verbose

View File

@ -0,0 +1,53 @@
[versions]
# Plugins
agp = "8.13.0"
kotlin = "2.2.20"
# AndroidX
activity = "1.11.0"
appcompat = "1.7.1"
core-ktx = "1.17.0"
constraint-layout = "2.2.1"
datastore-preferences = "1.1.7"
# Material
material = "1.13.0"
# Testing
espresso-core = "3.7.0"
androidx-junit = "1.3.0"
junit = "4.13.2"
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
[libraries]
# AndroidX
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
#Material
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
# Testing
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
junit = { group = "junit", name = "junit", version.ref = "junit" }
[bundles]
androidx = [
"androidx-activity",
"androidx-appcompat",
"androidx-constraintlayout",
"androidx-core-ktx",
"androidx-datastore-preferences",
]

View File

@ -1,6 +1,6 @@
#Thu Dec 21 14:31:09 AEDT 2023
#Tue Apr 01 11:15:06 PDT 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -0,0 +1,78 @@
plugins {
alias(libs.plugins.android.library)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.arm.aichat"
compileSdk = 36
ndkVersion = "29.0.13113456"
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
abiFilters += listOf("arm64-v8a", "x86_64")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
arguments += "-DBUILD_SHARED_LIBS=ON"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DGGML_NATIVE=OFF"
arguments += "-DGGML_BACKEND_DL=ON"
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
}
}
aarMetadata {
minCompileSdk = 35
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.31.6"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlin {
jvmToolchain(17)
compileOptions {
targetCompatibility = JavaVersion.VERSION_17
}
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
publishing {
singleVariant("release") {
withJavadocJar()
}
}
}
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
}

View File

@ -0,0 +1,8 @@
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-keepclasseswithmembernames class * {
native <methods>;
}
-keep class kotlin.Metadata { *; }

View File

@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.31.6)
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
# --------------------------------------------------------------------------
# AI Chat library
# --------------------------------------------------------------------------
if(DEFINED ANDROID_ABI)
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(GGML_SYSTEM_ARCH "ARM")
set(GGML_CPU_KLEIDIAI ON)
set(GGML_OPENMP ON)
elseif(ANDROID_ABI STREQUAL "x86_64")
set(GGML_SYSTEM_ARCH "x86")
set(GGML_CPU_KLEIDIAI OFF)
set(GGML_OPENMP OFF)
else()
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
endif()
endif()
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama)
add_library(${CMAKE_PROJECT_NAME} SHARED
ai_chat.cpp)
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
${LLAMA_SRC}
${LLAMA_SRC}/common
${LLAMA_SRC}/include
${LLAMA_SRC}/ggml/include
${LLAMA_SRC}/ggml/src)
target_link_libraries(${CMAKE_PROJECT_NAME}
llama
common
android
log)

View File

@ -0,0 +1,565 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <cmath>
#include <string>
#include <unistd.h>
#include <sampling.h>
#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
std::ostringstream str;
for (size_t i = 0; i < values.size(); i++) {
str << values[i];
if (i < values.size() - 1) { str << delim; }
}
return str.str();
}
/**
* LLama resources: context, model, batch and sampler
*/
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
// Set llama log handler to Android
llama_log_set(aichat_android_log_callback, nullptr);
// Loading all CPU backend variants
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
LOGi("Loading backends from %s", path_to_backend);
ggml_backend_load_all_from_path(path_to_backend);
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated; Log handler set.");
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
llama_model_params model_params = llama_model_default_params();
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
auto *model = llama_model_load_from_file(model_path, model_params);
env->ReleaseStringUTFChars(jmodel_path, model_path);
if (!model) {
return 1;
}
g_model = model;
return 0;
}
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
if (!model) {
LOGe("%s: model cannot be null", __func__);
return nullptr;
}
// Multi-threading setup
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
(int) sysconf(_SC_NPROCESSORS_ONLN) -
N_THREADS_HEADROOM));
LOGi("%s: Using %d threads", __func__, n_threads);
// Context parameters setup
llama_context_params ctx_params = llama_context_default_params();
const int trained_context_size = llama_model_n_ctx_train(model);
if (n_ctx > trained_context_size) {
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
__func__, trained_context_size, n_ctx);
}
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = BATCH_SIZE;
ctx_params.n_ubatch = BATCH_SIZE;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
auto *context = llama_init_from_model(g_model, ctx_params);
if (context == nullptr) {
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
}
return context;
}
static common_sampler *new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
return common_sampler_init(g_model, sparams);
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return 1; }
g_context = context;
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
g_chat_templates = common_chat_templates_init(g_model, "");
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
return 0;
}
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto *reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
return backends.empty() ? "CPU" : join(backends, ",");
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
jint pl, jint nr) {
auto *context = init_context(g_model, pp);
if (!context) {
const auto *const err_msg = "Fail to init_context! Bench aborted.";
LOGe(err_msg);
return env->NewStringUTF(err_msg);
}
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const uint32_t n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp = %d)", pp);
common_batch_clear(g_batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(g_batch, 0, i, {0}, false);
}
g_batch.logits[g_batch.n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg = %d)", tg);
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(g_batch);
for (j = 0; j < pl; j++) {
common_batch_add(g_batch, 0, i, {j}, true);
}
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
llama_free(context);
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(g_model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
const auto backend = get_backend();
std::stringstream result;
result << std::setprecision(3);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
/**
* Completion loop's long-term states:
* - chat management
* - position tracking
*/
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;
static void reset_long_term_states(const bool clear_kv_cache = true) {
chat_msgs.clear();
system_prompt_position = 0;
current_position = 0;
if (clear_kv_cache)
llama_memory_clear(llama_get_memory(g_context), false);
}
/**
* TODO-hyin: implement sliding-window version as a better alternative
*
* Context shifting by discarding the older half of the tokens appended after system prompt:
* - take the [system_prompt_position] first tokens from the original prompt
* - take half of the last (system_prompt_position - system_prompt_position) tokens
* - recompute the logits in batches
*/
static void shift_context() {
const int n_discard = (current_position - system_prompt_position) / 2;
LOGi("%s: Discarding %d tokens", __func__, n_discard);
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
current_position -= n_discard;
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
chat_msgs.push_back(new_msg);
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
return formatted;
}
/**
* Completion loop's short-term states:
* - stop generation position
* - token chars caching
* - current assistant message being generated
*/
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;
static void reset_short_term_states() {
stop_generation_position = 0;
cached_token_chars.clear();
assistant_ss.str("");
}
static int decode_tokens_in_batches(
llama_context *context,
llama_batch &batch,
const llama_tokens &tokens,
const llama_pos start_pos,
const bool compute_last_logit = false) {
// Process tokens in batches using the global batch
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(batch);
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
// Shift context if current batch cannot fit into the context
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
shift_context();
}
// Add tokens to the batch with proper positions
for (int j = 0; j < cur_batch_size; j++) {
const llama_token token_id = tokens[i + j];
const llama_pos position = start_pos + i + j;
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
common_batch_add(batch, token_id, position, {0}, want_logit);
}
// Decode this batch
const int decode_result = llama_decode(context, batch);
if (decode_result) {
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
return 1;
}
}
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring jsystem_prompt
) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
std::string formatted_system_prompt(system_prompt);
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
// Format system prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
has_chat_template, has_chat_template);
for (auto id: system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Handle context overflow
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) system_tokens.size() > max_batch_size) {
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
__func__, (int) system_tokens.size(), max_batch_size);
return 1;
}
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
system_prompt_position = current_position = (int) system_tokens.size();
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring juser_prompt,
jint n_predict
) {
// Reset short-term states
reset_short_term_states();
// Obtain and tokenize user prompt
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
std::string formatted_user_prompt(user_prompt);
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
// Format user prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
}
// Decode formatted user prompts
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id: user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int user_prompt_size = (int) user_tokens.size();
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if (user_prompt_size > max_batch_size) {
const int skipped_tokens = user_prompt_size - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
}
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
current_position += user_prompt_size;
stop_generation_position = current_position + user_prompt_size + n_predict;
return 0;
}
static bool is_valid_utf8(const char *string) {
if (!string) { return true; }
const auto *bytes = (const unsigned char *) string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
JNIEnv *env,
jobject /*unused*/
) {
// Infinite text generation via context shifting
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Context full! Shifting...", __func__);
shift_context();
}
// Stop if reaching the marked position
if (current_position >= stop_generation_position) {
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
return nullptr;
}
// Sample next token
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
common_sampler_accept(g_sampler, new_token_id, true);
// Populate the batch with new token, then decode
common_batch_clear(g_batch);
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
if (llama_decode(g_context, g_batch) != 0) {
LOGe("%s: llama_decode() failed for generated token", __func__);
return nullptr;
}
// Update position
current_position++;
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
return nullptr;
}
// If not EOG, convert to text
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
cached_token_chars += new_token_chars;
// Create and return a valid UTF-8 Java string
jstring result = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
result = env->NewStringUTF(cached_token_chars.c_str());
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
assistant_ss << cached_token_chars;
cached_token_chars.clear();
} else {
LOGv("id: %d,\tappend to cache", new_token_id);
result = env->NewStringUTF("");
}
return result;
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Free up resources
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
llama_backend_free();
}

View File

@ -0,0 +1,61 @@
//
// Created by Han Yin on 10/31/25.
//
#ifndef AICHAT_LOGGING_H
#define AICHAT_LOGGING_H
#endif //AICHAT_LOGGING_H
#pragma once
#include <android/log.h>
#ifndef LOG_TAG
#define LOG_TAG "ai-chat"
#endif
#ifndef LOG_MIN_LEVEL
#if defined(NDEBUG)
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
#else
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
#endif
#endif
static inline int ai_should_log(int prio) {
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
}
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGv(...) ((void)0)
#endif
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGd(...) ((void)0)
#endif
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
default: return ANDROID_LOG_DEFAULT;
}
}
static inline void aichat_android_log_callback(enum ggml_log_level level,
const char* text,
void* /*user*/) {
const int prio = android_log_prio_from_ggml(level);
if (!ai_should_log(prio)) return;
__android_log_write(prio, LOG_TAG, text);
}

View File

@ -0,0 +1,14 @@
package com.arm.aichat
import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl
/**
* Main entry point for Arm's AI Chat library.
*/
object AiChat {
/**
* Get the inference engine single instance.
*/
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}

View File

@ -0,0 +1,89 @@
package com.arm.aichat
import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
* Interface defining the core LLM inference operations.
*/
interface InferenceEngine {
/**
* Current state of the inference engine
*/
val state: StateFlow<State>
/**
* Load a model from the given path.
*
* @throws UnsupportedArchitectureException if model architecture not supported
*/
suspend fun loadModel(pathToModel: String)
/**
* Sends a system prompt to the loaded model
*/
suspend fun setSystemPrompt(systemPrompt: String)
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
/**
* Runs a benchmark with the specified parameters.
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
/**
* Unloads the currently loaded model.
*/
suspend fun cleanUp()
/**
* Cleans up resources when the engine is no longer needed.
*/
fun destroy()
/**
* States of the inference engine
*/
sealed class State {
object Uninitialized : State()
object Initializing : State()
object Initialized : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object Benchmarking : State()
object ProcessingSystemPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
data class Error(val exception: Exception) : State()
}
companion object {
const val DEFAULT_PREDICT_LENGTH = 1024
}
}
val State.isUninterruptible
get() = this is State.Initializing ||
this is State.LoadingModel ||
this is State.UnloadingModel ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt
val State.isModelLoaded: Boolean
get() = this is State.ModelReady ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt ||
this is State.Generating
class UnsupportedArchitectureException : Exception()

View File

@ -0,0 +1,61 @@
package com.arm.aichat.gguf
import kotlin.collections.get
/**
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
* The `label` matches what llamacli prints.
*/
enum class FileType(val code: Int, val label: String) {
ALL_F32(0, "all F32"),
MOSTLY_F16(1, "F16"),
MOSTLY_Q4_0(2, "Q4_0"),
MOSTLY_Q4_1(3, "Q4_1"),
// 4 removed
MOSTLY_Q8_0(7, "Q8_0"),
MOSTLY_Q5_0(8, "Q5_0"),
MOSTLY_Q5_1(9, "Q5_1"),
/* Kquants ------------------------------------------------------------ */
MOSTLY_Q2_K (10, "Q2_K - Medium"),
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
MOSTLY_Q6_K (18, "Q6_K"),
/* IQ quants ----------------------------------------------------------- */
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
/* BF16 & Ternary ------------------------------------------------------ */
MOSTLY_BF16 (32, "BF16"),
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
/* Special flag -------------------------------------------------------- */
GUESSED(1024, "(guessed)"),
UNKNOWN(-1, "unknown");
companion object {
private val map = entries.associateBy(FileType::code)
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
}
}

View File

@ -0,0 +1,132 @@
package com.arm.aichat.gguf
import java.io.IOException
/**
* Structured metadata of GGUF
*/
data class GgufMetadata(
// Basic file info
val version: GgufVersion,
val tensorCount: Long,
val kvCount: Long,
// General info
val basic: BasicInfo,
val author: AuthorInfo? = null,
val additional: AdditionalInfo? = null,
val architecture: ArchitectureInfo? = null,
val baseModels: List<BaseModelInfo>? = null,
val tokenizer: TokenizerInfo? = null,
// Derivative info
val dimensions: DimensionsInfo? = null,
val attention: AttentionInfo? = null,
val rope: RopeInfo? = null,
val experts: ExpertsInfo? = null
) {
enum class GgufVersion(val code: Int, val label: String) {
/** First public draft; littleendian only, no alignment key. */
LEGACY_V1(1, "Legacy v1"),
/** Added splitfile support and some extra metadata keys. */
EXTENDED_V2(2, "Extended v2"),
/** Current spec: endianaware, mandatory alignment, fully validated. */
VALIDATED_V3(3, "Validated v3");
companion object {
fun fromCode(code: Int): GgufVersion =
entries.firstOrNull { it.code == code }
?: throw IOException("Unknown GGUF version code $code")
}
override fun toString(): String = "$label (code=$code)"
}
data class BasicInfo(
val uuid: String? = null,
val name: String? = null,
val nameLabel: String? = null,
val sizeLabel: String? = null, // Size label like "7B"
)
data class AuthorInfo(
val organization: String? = null,
val author: String? = null,
val doi: String? = null,
val url: String? = null,
val repoUrl: String? = null,
val license: String? = null,
val licenseLink: String? = null,
)
data class AdditionalInfo(
val type: String? = null,
val description: String? = null,
val tags: List<String>? = null,
val languages: List<String>? = null,
)
data class ArchitectureInfo(
val architecture: String? = null,
val fileType: Int? = null,
val vocabSize: Int? = null,
val finetune: String? = null,
val quantizationVersion: Int? = null,
)
data class BaseModelInfo(
val name: String? = null,
val author: String? = null,
val version: String? = null,
val organization: String? = null,
val url: String? = null,
val doi: String? = null,
val uuid: String? = null,
val repoUrl: String? = null,
)
data class TokenizerInfo(
val model: String? = null,
val bosTokenId: Int? = null,
val eosTokenId: Int? = null,
val unknownTokenId: Int? = null,
val paddingTokenId: Int? = null,
val addBosToken: Boolean? = null,
val addEosToken: Boolean? = null,
val chatTemplate: String? = null,
)
data class DimensionsInfo(
val contextLength: Int? = null,
val embeddingSize: Int? = null,
val blockCount: Int? = null,
val feedForwardSize: Int? = null,
)
data class AttentionInfo(
val headCount: Int? = null,
val headCountKv: Int? = null,
val keyLength: Int? = null,
val valueLength: Int? = null,
val layerNormEpsilon: Float? = null,
val layerNormRmsEpsilon: Float? = null,
)
data class RopeInfo(
val frequencyBase: Float? = null,
val dimensionCount: Int? = null,
val scalingType: String? = null,
val scalingFactor: Float? = null,
val attnFactor: Float? = null,
val originalContextLength: Int? = null,
val finetuned: Boolean? = null,
)
data class ExpertsInfo(
val count: Int? = null,
val usedCount: Int? = null,
)
}

View File

@ -0,0 +1,77 @@
package com.arm.aichat.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Interface for reading GGUF metadata from model files.
* Use `GgufMetadataReader.create()` to get an instance.
*/
interface GgufMetadataReader {
/**
* Reads the magic number from the specified file path.
*
* @param file Java File to the GGUF file with absolute path
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(file: File): Boolean
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining [android.content.ContentProvider]
* @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
/**
* Reads and parses GGUF metadata from the specified file path.
*
* @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object {
private val DEFAULT_SKIP_KEYS = setOf(
"tokenizer.chat_template",
"tokenizer.ggml.scores",
"tokenizer.ggml.tokens",
"tokenizer.ggml.token_type"
)
/**
* Creates a default GgufMetadataReader instance
*/
fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = DEFAULT_SKIP_KEYS,
arraySummariseThreshold = 1_000
)
/**
* Creates a GgufMetadataReader with custom configuration
*
* @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
* @param arraySummariseThreshold If 0, arrays longer get summarised, not materialised;
* If -1, never summarise.
*/
fun create(
skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
arraySummariseThreshold: Int = 1_000
): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = skipKeys,
arraySummariseThreshold = arraySummariseThreshold
)
}
}
class InvalidFileFormatException : IOException()

View File

@ -0,0 +1,309 @@
package com.arm.aichat.internal
import android.content.Context
import android.util.Log
import com.arm.aichat.InferenceEngine
import com.arm.aichat.UnsupportedArchitectureException
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
import dalvik.annotation.optimization.FastNative
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException
/**
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
*
* This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
* All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
* with the underlying C++ native code.
*
* The typical usage flow is:
* 1. Get instance via [getInstance]
* 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams
* 5. Perform [cleanUp] when done with a model
* 6. Properly [destroy] when completely done
*
* State transitions are managed automatically and validated at each operation.
*
* @see ai_chat.cpp for the native implementation details
*/
internal class InferenceEngineImpl private constructor(
private val nativeLibDir: String
) : InferenceEngine {
companion object {
private val TAG = InferenceEngineImpl::class.java.simpleName
@Volatile
private var instance: InferenceEngine? = null
/**
* Create or obtain [InferenceEngineImpl]'s single instance.
*
* @param Context for obtaining native library directory
* @throws IllegalArgumentException if native library path is invalid
* @throws UnsatisfiedLinkError if library failed to load
*/
internal fun getInstance(context: Context) =
instance ?: synchronized(this) {
val nativeLibDir = context.applicationInfo.nativeLibraryDir
require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
try {
Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
InferenceEngineImpl(nativeLibDir).also { instance = it }
} catch (e: UnsatisfiedLinkError) {
Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
throw e
}
}
}
/**
* JNI methods
* @see ai_chat.cpp
*/
@FastNative
private external fun init(nativeLibDir: String)
@FastNative
private external fun load(modelPath: String): Int
@FastNative
private external fun prepare(): Int
@FastNative
private external fun systemInfo(): String
@FastNative
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
@FastNative
private external fun processSystemPrompt(systemPrompt: String): Int
@FastNative
private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
@FastNative
private external fun generateNextToken(): String?
@FastNative
private external fun unload()
@FastNative
private external fun shutdown()
private val _state =
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
override val state: StateFlow<InferenceEngine.State> = _state
private var _readyForSystemPrompt = false
/**
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
*/
@OptIn(ExperimentalCoroutinesApi::class)
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
init {
llamaScope.launch {
try {
check(_state.value is InferenceEngine.State.Uninitialized) {
"Cannot load native library in ${_state.value.javaClass.simpleName}!"
}
_state.value = InferenceEngine.State.Initializing
Log.i(TAG, "Loading native library...")
System.loadLibrary("ai-chat")
init(nativeLibDir)
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
} catch (e: Exception) {
Log.e(TAG, "Failed to load native library", e)
throw e
}
}
}
/**
* Load the LLM
*/
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.Initialized) {
"Cannot load model in ${_state.value.javaClass.simpleName}!"
}
try {
Log.i(TAG, "Checking access to model file... \n$pathToModel")
File(pathToModel).let {
require(it.exists()) { "File not found" }
require(it.isFile) { "Not a valid file" }
require(it.canRead()) { "Cannot read file" }
}
Log.i(TAG, "Loading model... \n$pathToModel")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.LoadingModel
load(pathToModel).let {
// TODO-han.yin: find a better way to pass other error codes
if (it != 0) throw UnsupportedArchitectureException()
}
prepare().let {
if (it != 0) throw IOException("Failed to prepare resources")
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}
/**
* Process the plain text system prompt
*
* TODO-han.yin: return error code if system prompt not correct processed?
*/
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
}
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingSystemPrompt
processSystemPrompt(prompt).let { result ->
if (result != 0) {
RuntimeException("Failed to process system prompt: $result").also {
_state.value = InferenceEngine.State.Error(it)
throw it
}
}
}
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
}
/**
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
*/
override fun sendUserPrompt(
message: String,
predictLength: Int,
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
_state.value = InferenceEngine.State.Generating
while (true) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
} catch (e: CancellationException) {
Log.i(TAG, "Generation cancelled by user.")
_state.value = InferenceEngine.State.ModelReady
throw e
} catch (e: Exception) {
Log.e(TAG, "Error during generation!", e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}.flowOn(llamaDispatcher)
/**
* Benchmark the model
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.ModelReady) {
"Benchmark request discarded due to: $state"
}
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
_readyForSystemPrompt = false // Just to be safe
_state.value = InferenceEngine.State.Benchmarking
benchModel(pp, tg, pl, nr).also {
_state.value = InferenceEngine.State.ModelReady
}
}
/**
* Unloads the model and frees resources, or reset error states
*/
override suspend fun cleanUp() =
withContext(llamaDispatcher) {
when (val state = _state.value) {
is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.UnloadingModel
unload()
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Model unloaded!")
Unit
}
is InferenceEngine.State.Error -> {
Log.i(TAG, "Resetting error states...")
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "States reset!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}
/**
* Cancel all ongoing coroutines and free GGML backends
*/
override fun destroy() {
_readyForSystemPrompt = false
llamaScope.cancel()
when(_state.value) {
is InferenceEngine.State.Uninitialized -> {}
is InferenceEngine.State.Initialized -> shutdown()
else -> { unload(); shutdown() }
}
}
}

View File

@ -0,0 +1,590 @@
package com.arm.aichat.internal.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.arm.aichat.gguf.InvalidFileFormatException
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Utility class to read GGUF model files and extract metadata key-value pairs.
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
*/
internal class GgufMetadataReaderImpl(
private val skipKeys: Set<String>,
private val arraySummariseThreshold: Int,
) : GgufMetadataReader {
companion object {
private const val ARCH_LLAMA = "llama"
}
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
enum class MetadataType(val code: Int) {
UINT8(0), INT8(1), UINT16(2), INT16(3),
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
companion object {
private val codeMap = entries.associateBy(MetadataType::code)
fun fromCode(code: Int): MetadataType = codeMap[code]
?: throw IOException("Unknown metadata value type code: $code")
}
}
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
sealed class MetadataValue {
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
}
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
private fun MetadataValue.toPrimitive(): Any = when (this) {
is MetadataValue.UInt8 -> value
is MetadataValue.Int8 -> value
is MetadataValue.UInt16 -> value
is MetadataValue.Int16 -> value
is MetadataValue.UInt32 -> value
is MetadataValue.Int32 -> value
is MetadataValue.Float32 -> value
is MetadataValue.Bool -> value
is MetadataValue.StringVal -> value
is MetadataValue.UInt64 -> value
is MetadataValue.Int64 -> value
is MetadataValue.Float64 -> value
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
}
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(file: File): Boolean =
file.inputStream().buffered().use { ensureMagic(it) }
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean =
ByteArray(4).let {
if (input.read(it) != 4) throw IOException("Not a valid file!")
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
}
/**
* Highlevel entry point: parses a `.gguf` file on disk and returns the fully
* populated [GgufMetadata] tree.
*
* Steps performed internally:
* 1. Reads and validates the 8byte header (`"GGUF"` magic + version).
* 2. Streams through the keyvalue section, skipping large blobs if the key
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
* 3. Converts the resulting raw map into stronglytyped substructures
* (basic info, tokenizer, rope, etc.).
*
* The method is STREAMINGONLY: tensors are never mapped or loaded into
* memory, so even multiGB model files can be processed in < 50 ms.
*
* @param path Absolute or relative filesystem path to a `.gguf` file.
* @return A [GgufMetadata] instance containing all recognised metadata plus
* an `allMetadata` map with any keys that were not given a dedicated
* field.
* @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt.
*/
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
// ── 1. header ──────────────────────────────────────────────────────────
// throws on mismatch
val version = ensureMagicAndVersion(input)
val tensorCount = readLittleLong(input)
val kvCount = readLittleLong(input)
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
// ── 3. build structured object ────────────────────────────────────────
return buildStructured(meta, version, tensorCount, kvCount)
}
/** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
if (!ensureMagic(input)) throw InvalidFileFormatException()
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
}
/**
* Read an unsigned 32bit littleendian integer.
*
* @throws IOException if fewer than four bytes are available.
*/
private fun readLEUInt32(input: InputStream): Int {
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
return (b3 and 0xFF shl 24) or
(b2 and 0xFF shl 16) or
(b1 and 0xFF shl 8) or
(b0 and 0xFF)
}
/**
* Lowlevel helper that reads the entire key-value section from the current
* stream position.
*
* @param input Open stream positioned JUST AFTER the header.
* @param kvCnt Number of keyvalue pairs (taken from the header).
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
*
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly.
*/
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
mutableMapOf<String, MetadataValue>().apply {
repeat(kvCnt.toInt()) {
val key = readString(input)
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
if (key in skipKeys) {
skipValue(input, valueT)
} else {
this[key] = parseValue(input, valueT)
}
}
}
/**
* Converts a flat [Map]<[String], [MetadataValue]> into the stronglytyped
* [GgufMetadata] tree used by the rest of the app.
*
* Only the keys listed in the spec are copied into dedicated data classes;
* everything else is preserved in `GgufMetadata.allMetadata`.
*
* @param m Raw key/value map.
* @param version GGUF fileformat version (enum).
* @param tensorCnt Number of tensors (from the header).
* @param kvCnt Total metadata pair count (from the header).
*/
private fun buildStructured(
m: Map<String, MetadataValue>,
version: GgufMetadata.GgufVersion,
tensorCnt: Long,
kvCnt: Long
): GgufMetadata {
// ---------- helpers ----------
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
fun String.strList(): List<String>? =
(m[this] as? MetadataValue.ArrayVal)
?.elements
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
val arch = "general.architecture".str() ?: ARCH_LLAMA
// -------------- populate sections ----------------
val basic = GgufMetadata.BasicInfo(
uuid = "general.uuid".str(),
name = "general.basename".str(),
nameLabel = "general.name".str(),
sizeLabel = "general.size_label".str()
)
val author = GgufMetadata.AuthorInfo(
organization = "general.organization".str(),
author = "general.author".str(),
doi = "general.doi".str(),
url = "general.url".str(),
repoUrl = "general.repo_url".str(),
license = "general.license".str(),
licenseLink = "general.license.link".str()
).takeUnless {
organization == null && author == null && doi == null &&
url == null && repoUrl == null && license == null && licenseLink == null
}
val additional = GgufMetadata.AdditionalInfo(
type = "general.type".str(),
description = "general.description".str(),
tags = "general.tags".strList(),
languages = "general.languages".strList()
).takeUnless {
type == null && description == null && tags == null && languages == null
}
val architectureInfo = GgufMetadata.ArchitectureInfo(
architecture = arch,
fileType = "general.file_type".u32(),
vocabSize = "$arch.vocab_size".u32(),
finetune = "general.finetune".str(),
quantizationVersion = "general.quantization_version".u32()
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
val baseModels = buildList {
val n = "general.base_model.count".u32() ?: 0
for (i in 0 until n) {
fun k(s: String) = "general.base_model.$i.$s"
add(
GgufMetadata.BaseModelInfo(
name = k("name").str(),
author = k("author").str(),
version = k("version").str(),
organization = k("organization").str(),
url = k("url").str(),
doi = k("doi").str(),
uuid = k("uuid").str(),
repoUrl = k("repo_url").str(),
)
)
}
}.takeIf { it.isNotEmpty() }
val tokenizer = GgufMetadata.TokenizerInfo(
model = "tokenizer.ggml.model".str(),
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
chatTemplate = "tokenizer.chat_template".str()
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
unknownTokenId == null && paddingTokenId == null &&
addBosToken == null && addEosToken == null && chatTemplate == null
}
val dimensions = GgufMetadata.DimensionsInfo(
contextLength = "$arch.context_length".u32(),
embeddingSize = "$arch.embedding_length".u32(),
blockCount = "$arch.block_count".u32(),
feedForwardSize = "$arch.feed_forward_length".u32()
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
val attention = GgufMetadata.AttentionInfo(
headCount = "$arch.attention.head_count".u32(),
headCountKv = "$arch.attention.head_count_kv".u32(),
keyLength = "$arch.attention.key_length".u32(),
valueLength = "$arch.attention.value_length".u32(),
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
layerNormEpsilon == null && layerNormRmsEpsilon == null
}
val rope = GgufMetadata.RopeInfo(
frequencyBase = "$arch.rope.freq_base".f32(),
dimensionCount = "$arch.rope.dimension_count".u32(),
scalingType = "$arch.rope.scaling.type".str(),
scalingFactor = "$arch.rope.scaling.factor".f32(),
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
finetuned = "$arch.rope.scaling.finetuned".bool()
).takeUnless { frequencyBase == null && dimensionCount == null &&
scalingType == null && scalingFactor == null && attnFactor == null &&
originalContextLength == null && finetuned == null
}
val experts = GgufMetadata.ExpertsInfo(
count = "$arch.expert_count".u32(),
usedCount = "$arch.expert_used_count".u32()
).takeUnless { count == null && usedCount == null }
return GgufMetadata(
version = version,
tensorCount = tensorCnt,
kvCount = kvCnt,
basic = basic,
author = author,
additional = additional,
architecture = architectureInfo,
baseModels = baseModels,
tokenizer = tokenizer,
dimensions = dimensions,
attention = attention,
rope = rope,
experts = experts
)
}
/**
* Recursively parses a metadata value of the given type from the input stream.
* @param input The input stream positioned at the start of the value.
* @param type The metadata value type to parse.
*/
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
MetadataType.UINT8 -> {
// 1-byte unsigned integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
MetadataValue.UInt8(byteVal.toUByte())
}
MetadataType.INT8 -> {
// 1-byte signed integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
MetadataValue.Int8(byteVal.toByte())
}
MetadataType.UINT16 -> {
// 2-byte unsigned integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
// Combine two bytes (little-endian) into an unsigned 16-bit value
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.UInt16(u16.toUShort())
}
MetadataType.INT16 -> {
// 2-byte signed integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
// Combine to 16-bit and interpret as signed
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.Int16(i16.toShort())
}
MetadataType.UINT32 -> {
// 4-byte unsigned integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.UInt32(u32.toUInt())
}
MetadataType.INT32 -> {
// 4-byte signed integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
// Combine four bytes into a 32-bit signed int
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
MetadataValue.Int32(i32)
}
MetadataType.FLOAT32 -> {
// 4-byte IEEE 754 float (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
val bits = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
val floatVal = Float.fromBits(bits)
MetadataValue.Float32(floatVal)
}
MetadataType.BOOL -> {
// 1-byte boolean (0 = false, 1 = true)
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
if (byteVal != 0 && byteVal != 1) {
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
}
MetadataValue.Bool(byteVal != 0)
}
MetadataType.STRING -> {
// UTF-8 string (length-prefixed with 8-byte length)
val str = readString(input)
MetadataValue.StringVal(str)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
val count = len.toInt()
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
// fastforward without allocation
repeat(count) { skipValue(input, elemType) }
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
} else {
val list = ArrayList<MetadataValue>(count)
repeat(count) { list += parseValue(input, elemType) }
MetadataValue.ArrayVal(elemType, list)
}
}
MetadataType.UINT64 -> {
// 8-byte unsigned integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
(bytes[6].toULong() and 0xFFuL shl 48) or
(bytes[5].toULong() and 0xFFuL shl 40) or
(bytes[4].toULong() and 0xFFuL shl 32) or
(bytes[3].toULong() and 0xFFuL shl 24) or
(bytes[2].toULong() and 0xFFuL shl 16) or
(bytes[1].toULong() and 0xFFuL shl 8) or
(bytes[0].toULong() and 0xFFuL)
MetadataValue.UInt64(u64)
}
MetadataType.INT64 -> {
// 8-byte signed integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
// Combine 8 bytes into a signed 64-bit value (Long)
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.Int64(i64)
}
MetadataType.FLOAT64 -> {
// 8-byte IEEE 754 double (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
val doubleVal = Double.fromBits(bits)
MetadataValue.Float64(doubleVal)
}
}
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
this?.takeIf { !it.check() }
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
private fun skipValue(input: InputStream, type: MetadataType) {
when (type) {
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
MetadataType.STRING -> {
val len = readLittleLong(input); input.skipFully(len)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
}
}
}
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
private fun readLittleLong(input: InputStream): Long {
val bytes = ByteArray(8)
input.readFully(bytes)
// Combine 8 bytes into a 64-bit value (Little Endian).
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
// In our context (lengths/counts), such extremely large values are not expected.
return (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
}
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
private fun readString(input: InputStream): String =
// Read 8-byte little-endian length (number of bytes in the string).
readLittleLong(input).let { len ->
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
// Read the UTF-8 bytes of the given length.
ByteArray(len.toInt()).let {
if (it.isNotEmpty()) input.readFully(it)
String(it, Charsets.UTF_8)
}
}
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
// Note: assumes bytes length is 4.
(bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
/**
* Robust skip that works the same on JDK 11 and Androids desugared runtime.
*
* @param n Number of bytes to advance in the stream.
* @throws IOException on premature EOF.
*/
private fun InputStream.skipFully(n: Long) {
var remaining = n
val scratch = ByteArray(8192) // readandtoss buffer
while (remaining > 0) {
val skipped = skip(remaining)
when {
skipped > 0 -> remaining -= skipped // normal fast path
skipped == 0L -> {
// fallback: read and discard
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
if (read == -1) throw IOException("EOF while skipping $n bytes")
remaining -= read
}
else -> throw IOException("Skip returned negative value")
}
}
}
/**
* Extension that keeps reading until the requested number of bytes are filled.
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
* streams.
*
* @param buf Destination buffer.
* @param len Number of bytes to fill (defaults to `buf.size`).
* @throws IOException on premature EOF.
*/
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
var off = 0
while (off < len) {
val n = read(buf, off, len - off)
if (n == -1) throw IOException("EOF after $off of $len bytes")
off += n
}
}
/**
* Read EXACTLY `n` bytes or throw never returns a partiallyfilled array.
* This is used for small fixedlength reads (e.g. 4byte type codes).
*
* @throws IOException on premature EOF.
*/
private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
if (read(it) != n) throw IOException("Unexpected EOF")
}
}

View File

@ -1,71 +0,0 @@
plugins {
id("com.android.library")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "android.llama.cpp"
compileSdk = 34
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
// Add NDK properties if wanted, e.g.
// abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
arguments += "-DCMAKE_BUILD_TYPE=Release"
cppFlags += listOf()
arguments += listOf()
cppFlags("")
}
}
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.22.1"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
}
dependencies {
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.appcompat:appcompat:1.6.1")
implementation("com.google.android.material:material:1.11.0")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}

View File

@ -1,53 +0,0 @@
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html.
# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)
# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("llama-android")
#include(FetchContent)
#FetchContent_Declare(
# llama
# GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
# GIT_TAG master
#)
# Also provides "common"
#FetchContent_MakeAvailable(llama)
# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.
#
# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
# is preferred for the same purpose.
#
#load local llama.cpp
add_subdirectory(../../../../../../ build-llama)
# In order to load a library into your app from Java/Kotlin, you must call
# System.loadLibrary() and pass the name of the library defined here;
# for GameActivity/NativeActivity derived applications, the same library name must be
# used in the AndroidManifest.xml file.
add_library(${CMAKE_PROJECT_NAME} SHARED
# List C/C++ source files with relative paths to this CMakeLists.txt.
llama-android.cpp)
# Specifies libraries CMake should link to your target library. You
# can link libraries from various origins, such as libraries defined in this
# build script, prebuilt third-party libraries, or Android system libraries.
target_link_libraries(${CMAKE_PROJECT_NAME}
# List libraries link to the target library
llama
common
android
log)

View File

@ -1,452 +0,0 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <math.h>
#include <string>
#include <unistd.h>
#include "llama.h"
#include "common.h"
// Write C++ code here.
//
// Do not forget to dynamically load the C++ library into your application.
//
// For instance,
//
// In MainActivity.java:
// static {
// System.loadLibrary("llama-android");
// }
//
// Or, in MainActivity.kt:
// companion object {
// init {
// System.loadLibrary("llama-android")
// }
// }
#define TAG "llama-android.cpp"
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;
std::string cached_token_chars;
bool is_valid_utf8(const char * string) {
if (!string) {
return true;
}
const unsigned char * bytes = (const unsigned char *)string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
llama_model_params model_params = llama_model_default_params();
auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model);
auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) {
LOGe("load_model() failed");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
return 0;
}
return reinterpret_cast<jlong>(model);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_model_free(reinterpret_cast<llama_model *>(model));
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
auto model = reinterpret_cast<llama_model *>(jmodel);
if (!model) {
LOGe("new_context(): model cannot be null");
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
return 0;
}
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
LOGi("Using %d threads", n_threads);
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
llama_context * context = llama_new_context_with_model(model, ctx_params);
if (!context) {
LOGe("llama_new_context_with_model() returned null)");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
"llama_new_context_with_model() returned null)");
return 0;
}
return reinterpret_cast<jlong>(context);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
llama_free(reinterpret_cast<llama_context *>(context));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
llama_backend_free();
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
llama_log_set(log_callback, NULL);
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_bench_1model(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong model_pointer,
jlong batch_pointer,
jint pp,
jint tg,
jint pl,
jint nr
) {
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto model = reinterpret_cast<llama_model *>(model_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const int n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp)");
common_batch_clear(*batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(*batch, 0, i, { 0 }, false);
}
batch->logits[batch->n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg)");
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(*batch);
for (j = 0; j < pl; j++) {
common_batch_add(*batch, 0, i, { j }, true);
}
LOGi("llama_decode() text generation: %d", i);
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
const auto backend = "(Android)"; // TODO: What should this be?
std::stringstream result;
result << std::setprecision(2);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
llama_batch *batch = new llama_batch {
0,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
};
if (embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
} else {
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return reinterpret_cast<jlong>(batch);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
delete batch;
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = true;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
return reinterpret_cast<jlong>(smpl);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
llama_backend_init();
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1init(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jstring jtext,
jboolean format_chat,
jint n_len
) {
cached_token_chars.clear();
const auto text = env->GetStringUTFChars(jtext, 0);
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + n_len;
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
if (n_kv_req > n_ctx) {
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
}
for (auto id : tokens_list) {
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}
common_batch_clear(*batch);
// evaluate the initial prompt
for (auto i = 0; i < tokens_list.size(); i++) {
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
}
// llama_decode will output logits only for the last token of the prompt
batch->logits[batch->n_tokens - 1] = true;
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");
}
env->ReleaseStringUTFChars(jtext, text);
return batch->n_tokens;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jlong sampler_pointer,
jint n_len,
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context);
const auto vocab = llama_model_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
// sample the most likely token
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr;
}
auto new_token_chars = common_token_to_piece(context, new_token_id);
cached_token_chars += new_token_chars;
jstring new_token = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
new_token = env->NewStringUTF(cached_token_chars.c_str());
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
cached_token_chars.clear();
} else {
new_token = env->NewStringUTF("");
}
common_batch_clear(*batch);
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() returned null");
}
return new_token;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
}

View File

@ -1,180 +0,0 @@
package android.llama.cpp
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread
class LLamaAndroid {
private val tag: String? = this::class.simpleName
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = "Llm-RunLoop") {
Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary("llama-android")
// Set llama log handler to Android
log_to_android()
backend_init(false)
Log.d(tag, system_info())
it.run()
}.apply {
uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
Log.e(tag, "Unhandled exception", exception)
}
}
}.asCoroutineDispatcher()
private val nlen: Int = 64
private external fun log_to_android()
private external fun load_model(filename: String): Long
private external fun free_model(model: Long)
private external fun new_context(model: Long): Long
private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean)
private external fun backend_free()
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
private external fun free_batch(batch: Long)
private external fun new_sampler(): Long
private external fun free_sampler(sampler: Long)
private external fun bench_model(
context: Long,
model: Long,
batch: Long,
pp: Int,
tg: Int,
pl: Int,
nr: Int
): String
private external fun system_info(): String
private external fun completion_init(
context: Long,
batch: Long,
text: String,
formatChat: Boolean,
nLen: Int
): Int
private external fun completion_loop(
context: Long,
batch: Long,
sampler: Long,
nLen: Int,
ncur: IntVar
): String?
private external fun kv_cache_clear(context: Long)
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
Log.d(tag, "bench(): $state")
bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
}
else -> throw IllegalStateException("No model loaded")
}
}
}
suspend fun load(pathToModel: String) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.Idle -> {
val model = load_model(pathToModel)
if (model == 0L) throw IllegalStateException("load_model() failed")
val context = new_context(model)
if (context == 0L) throw IllegalStateException("new_context() failed")
val batch = new_batch(512, 0, 1)
if (batch == 0L) throw IllegalStateException("new_batch() failed")
val sampler = new_sampler()
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
Log.i(tag, "Loaded model $pathToModel")
threadLocalState.set(State.Loaded(model, context, batch, sampler))
}
else -> throw IllegalStateException("Model already loaded")
}
}
}
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {
break
}
emit(str)
}
kv_cache_clear(state.context)
}
else -> {}
}
}.flowOn(runLoop)
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
*/
suspend fun unload() {
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
free_context(state.context)
free_model(state.model)
free_batch(state.batch)
free_sampler(state.sampler);
threadLocalState.set(State.Idle)
}
else -> {}
}
}
}
companion object {
private class IntVar(value: Int) {
@Volatile
var value: Int = value
private set
fun inc() {
synchronized(this) {
value += 1
}
}
}
private sealed interface State {
data object Idle: State
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
}
// Enforce only one instance of Llm.
private val _instance: LLamaAndroid = LLamaAndroid()
fun instance(): LLamaAndroid = _instance
}
}

View File

@ -8,11 +8,11 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
google()
}
}
rootProject.name = "LlamaAndroid"
rootProject.name = "AiChat"
include(":app")
include(":llama")
include(":lib")

View File

@ -386,6 +386,9 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
elseif (APPLE)
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)