Compare commits

..

18 Commits

Author SHA1 Message Date
Georgi Gerganov 4301e27319
common : restore grammar-based rejection sampling (#18137)
* common : restart grammar-based rejection sampling

* sampling : allow null samplers
2025-12-17 19:46:00 +02:00
Johannes Gäßler a2c199e479
common: clarify instructions for bug reports (#18134) 2025-12-17 18:44:13 +01:00
HonestQiao 15dd67d869
model: fix GLM-ASR-Nano-2512 load error (#18130) (#18142) 2025-12-17 16:34:35 +01:00
Xuan-Son Nguyen bde461de8c
server: (router) allow child process to report status via stdout (#18110)
* server: (router) allow child process to report status via stdout

* apply suggestions
2025-12-17 14:54:11 +01:00
Piotr Wilkin (ilintar) 8faa87db02
Extend run-org-model.py, add (a) batching (b) loading prompt from file (c) multimodal capacity (#18034) 2025-12-17 14:21:51 +01:00
Johannes Gäßler 6f1f6a961a
Github: ask for -v logs for params_fit [no ci] (#18128) 2025-12-17 13:46:48 +01:00
Alberto Cabrera Pérez 669696e00d
ggml-cpu: ARM64: repack version of q8_0 (dotprod and i8mm) (#18096)
* wip: skeleton for q8_0 repack

* q8_0 repack GEMV implementations

* GEMM implementations

* Formatting

* Fixed format consistency of repack gemm and gemv declarations

* gemv and gemm generic location consistent with declarations

* Removed non-correct unused variables statements

* Cleanup, consistent style

* Missing generic fallbacks for x86 and powerpc
2025-12-17 13:39:13 +02:00
Tarek Dakhran 982060fadc
model: fix LFM2_MOE missing tensors (#18132) 2025-12-17 12:17:11 +01:00
Sigbjørn Skjæret 6853bee680
ci : clean up webui jobs (#18116)
* clean up webui jobs

* refined step control

* forgot dependencies

* apparently always() is needed
2025-12-17 10:45:40 +01:00
Pascal 487674fbb3
common: fix --override-kv to support comma-separated values (#18056)
* common: fix --override-kv to support comma-separated values

* Update common/arg.cpp

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

* common: deprecate repeated arguments, suggest comma-separated values

* common: add comma escape support for --override-kv

* common: optimize duplicate detection with insert().second

Co-authored-by: personalmountains <46615898+personalmountains@users.noreply.github.com>

* common: migrate all repeated args to comma-separated syntax

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
Co-authored-by: personalmountains <46615898+personalmountains@users.noreply.github.com>
2025-12-17 11:36:23 +02:00
yulo acec774ef6
HIP: Refactor mma for RDNA and CDNA (#17990)
* mma.cuh for rdna4

* mma for rdna3

* mmq for rdna4

* mmq for rdna3

* align i-major and j-major

* cdna

* fix cuda error

* add missing tile of mfma

* fix j-major wrong ne on CDNA

* fix gramma and empty spaces

---------

Co-authored-by: zhang hui <you@example.com>
2025-12-17 09:34:54 +01:00
Naco Siren 5c0d18881e
llama.android : Rewrite Android binding (w/o cpu_features dep) (#17413)
* UI: implement basic UI components

* util: implement performance monitor; wrap it with a viewmodel

* util: implement user preferences utility

* UI: implement core flow's screens

* UI: add a new MainActivity; update manifest

* [WIP] DI: implement simple local vm factory provider

* UI: disable triggering drawer via gesture; enable alert dialog on back navigation inside conversation and benchmark

* UI: allow drawer's gesture control only on Home and Settings screens; enable alert dialog on back navigation inside conversation and benchmark

* UI: split a nested parent settings screen into separate child settings screens

* UI: polish system prompt setup UI

* Deps: bump Kotlin plugin; introduce KSP; apply in :app subproject

* DB: setup Room database

* data: introduce repo for System Prompt; flow data from Room to VM

* bugfix: properly handle user's quitting conversation screen while tokens in generation

* UI: rename `ModeSelection` to `ModelLoading` for better clarity

* UI: update app name to be more Arm

* UI: polish conversation screen

* data: code polish

* UI: code polish

* bugfix: handle user quitting on model loading

* UI: locks user in alert dialog when model is unloading

* vm: replace token metrics stubs with actual implementation

* UI: refactor top app bars

* nit: combine temperatureMetrics and useFahrenheit

* DI: introduce Hilt plugin + processor + lib dependencies

* DI: make app Hilt injectable

* DI: make viewmodels Hilt injectable

* DI: replace manual DI with Hilt DI

* UI: optimize AppContent's composing

* bugfix: wait for model to load before navigating to benchmark screen; use NavigationActions instead of raw navController

* UI: navigation with more natural animated transitions

* DI: Optimize AppModule

* Feature: Introduce ModelRepository and ModelsManagementViewModel; update AppModule

* UI: polish UI for ModelsManagementScreen; inject ModelsManagementVieModel

* DI: abstract the protocol of SystemPromptRepository; update AppModule

* data: [WIP] prepare for ModelRepository refactor & impl

* data: introduce Model entity and DAO; update DI module

* UI: replace Models Management screen's stubbing with instrumentation

* UI: polish sort order menu

* data: import local model with file picker

* bugfix: use List instead of Collection for ModelDao's deletion

* data: add a util file for extracting file name & size and model metadata

* UI: enrich ModelManagementState; extract filename to show correct importing UI

* UI: implement multiple models deletion; update Models Management screen

* UI: handle back navigation when user is in multi-selection mode

* util: extract file size formatting into ModelUtils

* UI: add a confirmation step when user picks a file; refactor model import overlay into AlertDialog

* UI: extract a shared ModelCard component

* UI: replace model selection screen's data stubbing; add empty view

* nit: tidy SystemPromptViewModel

* Util: split FileUtils from ModelUtils; extract copy methods into FileUtils

* data: pass through getModelById from ModelDao into ModelRepository

* core: extract conversation and benchmark logics into InferenceManager; add logs and missing state updates in stub InferenceEngine

* vm: split mono MainViewModel into separate individual ViewModels

* vm: merge SystemPromptViewModel into ModelLoadingViewModel

* core: break down InferenceManager due to Interface Segregation Principle

* UI: show model card in Model Loading screen

* UI: show model card in Conversation screen

* UI: unify Model Card components

* core: swap in LLamaAndroid and mark stub engine for testing only

* data: allow canceling the ongoing model import

* UI: update UI ongoing model import's cancellation

* LLama: update engine state after handling the cancellation of sendUserPrompt

* VM: handle the cancellation of ongoing token generation

* LLama: refactor loadModel by splitting the system prompt setting into a separate method

* feature: check for available space before copying local model

* UI: centralize the AppScaffold and modularize its configs

* UI: refactor BottomBarConfig.ModelsManagement APIs

* UI: combine TopBarConfig and BottomBarConfig into each route's ScaffoldConfig

* UI: replace ugly optional as casts in AppScaffold with extension functions

* UI: fix the typo `totalGb` in `StorageMetrics`

* UI: remove code duplication in sort menu

* LLama: add ModelUnloadingState to engine State; add missing state checks in stub engine; fix instrumentation engine's error messages

* UI: refactor back handling by removing centralized BackHandlerSetup and UnloadModelConfirmationDialog from AppContent

* UI: implement BenchmarkScreen's individual back handling

* LLama: add a new Initializing state; ; add two extension properties; rename LibraryLoaded state to Initialized

* UI: Introduce an abstract ViewModel to handle additional model unloading logics

* UI: expose a single facade ModelUnloadDialogHandler; move UnloadModelState into ModelUnloadingViewModel.kt

* UI: migrate ModelLoadingScreen onto ModelLoadingViewModel; update & refine ModelLoadingScreen

* UI: migrate ConversationViewModel onto ModelLoadingViewModel; update & refine ConversationScreen

* nit: extract app name into a constant value; remove unused onBackPressed callbacks

* UI: update AppContent to pass in correct navigation callbacks

* nit: polish ModelLoadingScreen UI

* core: throw Exception instead of returning null if model fails to load

* navigation: sink model loading state management from AppContent down into ModelLoadingScreen; pass ModelLoadingMetrics to Benchmark and Conversation screens

* gguf: add GGUF metadata data holder and its corresponding extractor implementation

* DB: introduce Kotlin serialization extension's library and plugin; add Room runtime library

* GGUF: make GgufMetadata serializable in order to be compatible with Room

* nit: refactor data.local package structure

* nit: rename lastUsed field to dateLastUsed; add dateAdded field

* UI: refactor ModelCard UI to show GGUF metadata

* UI: update ModelSelectionScreen with a preselect mechanism

* UI: polish model card

* nit: allow deselect model on Model Selection screen

* nit: revert accidental committing of debug code

* UI: polish ModelLoading screen

* util: extract formatting helper functions from FileUtils into a new FormatUtils

* UI: polish model cards on Benchmark and Conversation screens to show model loading metrics

* UI: show a Snack bar to warn user that system prompt is not always supported

* UI: handle back press on Model Selection screen

* UI: finally support theme modes; remove hardcoded color schemes, default to dynamic color scheme implementation

* feature: support searching on Model Selection screen

* nit: move scaffold related UI components into a separate package

* UI: extract InfoView out into a separate file for reusability

* data: move Model related actions (query, filter, sort) into ModelInfo file

* UI: animate FAB on model preselection states

* feature: support filtering in Model Management screen

* ui: show empty models info in Model Management screen

* ui: add filter off icon to "Clear filters" menu item

* [WIP] ui: polish Benchmark screen; implement its bottom app bar

* ui: polish Benchmark screen; implement its bottom app bar's rerun and share

* nit: disable mode selection's radio buttons when loading model

* feature: implement Conversation screen's bottom app bar

* pkg: restructure BottomAppBars into separate files in a child package

* pkg: restructure TopBarApps into separate files in a child package

* pkg: restructure system metrics into a separate file

* UI: polish Conversation screen

* data: update system prompt presets

* UI: allow hide or show model card on Conversation & Benchmark screens; fix message arrangement

* data: update & enhance system prompt presets

* deps: introduce Retrofit2

* data: implement HuggingFace data model, data source with Retrofit API

* data: update Model data repository to support fetching HuggingFace models

* [WIP] UI: replace the HuggingFace stub in Model Management screen with actual API call

* UI: map language codes into country Emojis

* ui: add "clear results" action to Benchmark screen

* nit: print current pp & tg in llama-bench

* UI: disable landscape mode; prevent duplicated benchmark running

* llama: migrate C/CXX flags into CMakeList

* [WIP] llama: ABI split builds five .so artifacts.

However, all .so are performing on SVE level

* [WIP] llama: ABI split where five tiers are built sequentially.

* [WIP] llama: disable OpenMP in ABI split since most SoCs are big.LITTLE

* [WIP] llama: enable KleidiAI and disable tier 4 due to `+sve+sve2` bug caused by `ggml_add_cpu_backend_variant_impl` as explained below

```CMake
if (NOT SME_ENABLED MATCHES -1)
...
    set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
...
```

* core: add Google's cpu_features as a submodule

* core: implement cpu_detector native lib

* core: swap out hardcoded LlamaAndroid library loading

* core: add back OpenMP due to huge perf loss on TG128

* misc: reorg the pkg structure

* misc: rename LlamaAndroid related class to InferenceEngine prefixes

* [WIP] lib: move GgufMetadata into the lib submodule

* lib: expose GgufMetadataReader as interface only

* lib: replace the naive & plain SharedPreferences with DataStore implementation

* lib: hide the internal implementations, only expose a facade and interfaces

* lib: expose Arm features

* di: add a stub TierDetection; provide both actual impl and stub in AppModule

* UI: add visualizer UI for Arm features

* misc: UI polish

* lib: refactored InferenceEngineLoader; added a `NONE` Llama Tier

* UI: support `NONE` Llama Tier in general settings

* lib: optimize engine loader; always perform a fresh detection when cache is null

* remote: add HuggingFaceModelDetails data class

* remote: refine HuggingFaceModel data class

* nit: remove `trendingScore` field from HuggingFace model entities, weird...

* remote: refactor HuggingFaceApiService; implement download feature in HuggingFaceRemoteDataSource

* remote: fix the incorrect parse of HuggingFace's inconsistent & weird JSON response

* UI: scaffold Models Management screen and view model

* UI: implement a dialog UI to show fetched HuggingFace models.

* UI: use a broadcast receiver to listen for download complete events and show local import dialog.

* data: handle network exceptions elegantly

* pkg: restructure `data`'s packages

* data: extract local file info, copy and cleanup logics into LocalFileDataSource

* nit: minor UI patch; add missing comments

* bugfix: tapping "Home" in navigation drawer should simply close it without any navigation action.

* UI: improve autoscroll during token generation

* lib: tested on JFrog Artifactory for Maven publishing

* UI: show RAM warning if model too large

* UI: polish model management screen's error dialog

* util: add more items into the mapping table of ISO 639-1 language code to ISO 3166-1 country code

* llm: properly propagate error to UI upon failing to load selected model

* UI: avoid duplicated calculation of token metrics

* lib: read & validate the magic number from the picked source file before executing the import

* UI: add "Learn More" hyperlinks to Error dialog upon model import failures

* lib: refactor the GgufMetadataReader to take  InputStream instead of absolute path as argument

* lib: fix the `SIMD` typo in Tier description

* core: verify model file path is readable

* lib: add UnsupportedArchitectureException for triaged error message

* util: split FormatUtils into multiple utils for better readability

* UI: change benchmark screen from raw markdown to table view

* bugfix: reset preselection upon running the preselected model

* misc: linter issue

* bugfix: fix the malfunctioning monitoring switch

* UI: update Arm features indicator; fix the broken hyperlinks

* UI: add quick action buttons to benchmark screen's result card

* UI: hide share fab after clearing all benchmark results

* UI: fix the model unload dialog message; elevate the model card and hide it by default on Conversation screen;

* UI: hide the stubbing actions in Conversation screen

* UI: add show/hide stats control to conversation screen's assistant message bubble; fix placeholder

* UI: add a info button to explain token metrics

* misc: remove the redundant `Companion` added due to refactoring

* UI: show corresponding system metrics detailed info upon tapping RAM / storage / temperature indicator

* UI: add info button to System Prompt switch; expand the model card by default

* UI: disable tag & language chips; add section headers to explain what they are

* misc: replace top bar indicator's spacer with padding

* UI: merge the Model Selection and Model Management into a unified Models screen

* UI: split the ModelsManagementViewModel from a unified ModelsViewModel due to huge complexity

* UI: add model loading in progress view; polish the empty model info view

* UI: polish the bottom bars and info view when no models found; show loading in progress while fetching models

* build: [BREAKING] bump the versions of libraries and plugins

* UI: fix the breaking build

* UI: add Tooltip on Import FAB for user onboarding

* UI: adds AppPreferences to track user onboarding status

* UI: tracks user's first success on importing a model

* data: add hand crafted rules to filter the models fetched from HuggingFace API

* UI: update app name & about; polish top bars' indicators & buttons

* UI: polish Hugging Face download dialog UI

* UX: implement onboarding tooltips for model import and onboarding

* misc: use sentence case for CTA button labels

* [WIP] UI: add Arm color palette from Philip.Watson3

* UI: address Rojin's UX feedbacks

* UI: address Rojin's UX feedbacks - part 2

* UI: update Arm color palette from Philip.Watson3

* data: make sure fetch preselected models in the same order of their IDs

* UI: fix UI issues in the generic settings screen and navigation drawer

* nit: address Rojin's feedbacks on model import message again

* nit: append `®` to all `Arm` labels

* UI: extract a reusable InfoAlertDialog

* core: support GGML_CPU_ALL_VARIANTS on Android!

* core: restructure Kleidi-Llama library

* core: organizing cmake arguments

* data: sort preselected models according to device's available RAM

* app: update adaptive + themed + legacy icons and app name

* UI: fix the font size auto scaling for ArmFeaturesVisualizer

* core: further improve the performance on native methods

* UI: minor color palette changes; emphasize the bottom bar FABs; fix Settings Screen menu item label

* UI: make more room for assistant message bubble's width

* UI: better usage of tertiary colors to highlight model cards but not for warnings

* UI: fix the layout issue on large font sizes

* lib: support x86-64 by dynamically set Arm related definitions

* lib: replace the factory pattern for  deprecated tiered lib loading with single instance pattern

* llama: update the library name in JNI and CMake project

* llama: update the library's package name and namespace

* llama: update the app's package name and namespace

* app: bump ksp version

* app: remove deprecated SystemUIController from accompanist by migrating to EdgeToEdge

* app: extract AppContent from MainActivity to a separate file in ui package

* lib: add File version for GGUF Magic number verification

* lib: perform engine state check inclusively instead of exclusively

* lib: change `LlamaTier` to `ArmCpuTier`

* lib: remove kleidi-llama related namings

* cleanup: remove Arm AI Chat/Playground app source code; replace with the basic sample app from https://github.com/hanyin-arm/Arm-AI-Chat-Sample

Note: the full Google Play version of AI Chat app will be open will be open sourced in another repo soon, therefore didn't go through the trouble of pruning the history using `git filter-repo` here.

* [WIP] doc: update main and Android README docs; add self to code owners

* lib: revert System.load back to System.loadLibrary

* jni: introduce a logging util to filter different logging levels on different build types

* lib: enable app optimization

* doc: replace stub Google Play app URL with the actual link add screenshots; add my GitHub ID to maintainer list

* Remove cpu_features

* Fix linters issues in editorconfig-checker job

https://github.com/ggml-org/llama.cpp/actions/runs/19548770247/job/55974800633?pr=17413

* Remove unnecessary Android CMake flag

* purge include/cpu_features directory

---------

Co-authored-by: Han Yin <han.yin@arm.com>
2025-12-17 10:14:47 +02:00
TrevorS 4b2a4778f8
arg: allow -kvu flag for llama-perplexity (#18117)
The -kvu (--kv-unified) flag is required for hellaswag and winogrande
benchmarks which use coupled sequences. Without unified KV cache,
these benchmarks fail with:

  split_equal: sequential split is not supported when there are
  coupled sequences in the input batch (you may need to use the -kvu flag)

This change adds LLAMA_EXAMPLE_PERPLEXITY to the allowed examples for
the -kvu argument, enabling its use with llama-perplexity.
2025-12-17 08:33:02 +02:00
Aadeshveer Singh 58062860af
ggml : use WARP_SIZE/2 for argmax reduction offset (#18092) 2025-12-17 11:47:01 +08:00
Yuri Khrustalev 2973a65ecb
gguf-py : allow converting multi-tensor models from read-only locations (#18100) 2025-12-17 02:27:03 +01:00
Johannes Gäßler d0794e89d9
llama-fit-params: force disable mlock (#18103) 2025-12-17 00:50:12 +01:00
Johannes Gäßler 9dcac6cf9f
llama-fit-params: lower ctx size for multi GPU (#18101) 2025-12-17 00:49:34 +01:00
Johannes Gäßler 0e49a7b8b4
llama-fit-params: fix underflow for dense models (#18095) 2025-12-17 00:47:37 +01:00
75 changed files with 3661 additions and 1661 deletions

View File

@ -86,6 +86,7 @@ body:
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
This will be automatically formatted into code, so no need for backticks.
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
render: shell
validations:
required: false

View File

@ -31,9 +31,10 @@ concurrency:
cancel-in-progress: true
jobs:
webui-setup:
name: WebUI Setup
webui-check:
name: WebUI Checks
runs-on: ubuntu-latest
continue-on-error: true
steps:
- name: Checkout code
uses: actions/checkout@v4
@ -42,137 +43,66 @@ jobs:
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
id: node
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Cache node_modules
uses: actions/cache@v4
id: cache-node-modules
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
id: setup
if: ${{ steps.node.conclusion == 'success' }}
run: npm ci
working-directory: tools/server/webui
webui-check:
needs: webui-setup
name: WebUI Check
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Run type checking
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run check
working-directory: tools/server/webui
- name: Run linting
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run lint
working-directory: tools/server/webui
webui-build:
needs: webui-check
name: WebUI Build
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Build application
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run build
working-directory: tools/server/webui
webui-tests:
needs: webui-build
name: Run WebUI tests
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Install Playwright browsers
id: playwright
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npx playwright install --with-deps
working-directory: tools/server/webui
- name: Build Storybook
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run build-storybook
working-directory: tools/server/webui
- name: Run Client tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:client
working-directory: tools/server/webui
- name: Run Server tests
run: npm run test:server
- name: Run Unit tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:unit
working-directory: tools/server/webui
- name: Run UI tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:ui -- --testTimeout=60000
working-directory: tools/server/webui
- name: Run E2E tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:e2e
working-directory: tools/server/webui
server-build:
needs: [webui-tests]
runs-on: ubuntu-latest
strategy:

View File

@ -32,7 +32,7 @@
/examples/export-docs/ @ggerganov
/examples/gen-docs/ @ggerganov
/examples/gguf/ @ggerganov
/examples/llama.android/ @ggerganov
/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren
/examples/llama.swiftui/ @ggerganov
/examples/llama.vim @ggerganov
/examples/lookahead/ @ggerganov

View File

@ -190,6 +190,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
- Android: [llama.android](/examples/llama.android)
</details>

View File

@ -420,6 +420,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@ -430,6 +432,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
@ -750,6 +755,8 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
}
};
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@ -760,6 +767,9 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto opt = *arg_to_options[arg];
std::string val;
if (opt.value_hint != nullptr) {
@ -1140,7 +1150,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.kv_unified = true;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@ -1226,13 +1236,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION}));
add_opt(common_arg(
{"--in-file"}, "FNAME",
"an input file (repeat to specify multiple files)",
"an input file (use comma-separated values to specify multiple files)",
[](common_params & params, const std::string & value) {
std::ifstream file(value);
for (const auto & item : string_split<std::string>(value, ',')) {
std::ifstream file(item);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
}
params.in_files.push_back(item);
}
params.in_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
@ -1969,9 +1981,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
for (const auto & item : string_split<std::string>(value, ',')) {
params.image.emplace_back(item);
}
}
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
@ -2218,12 +2232,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
));
add_opt(common_arg(
{"--override-kv"}, "KEY=TYPE:VALUE",
"advanced option to override model metadata by key. may be specified multiple times.\n"
"types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false",
{"--override-kv"}, "KEY=TYPE:VALUE,...",
"advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
"types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
[](common_params & params, const std::string & value) {
if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str()));
std::vector<std::string> kv_overrides;
std::string current;
bool escaping = false;
for (const char c : value) {
if (escaping) {
current.push_back(c);
escaping = false;
} else if (c == '\\') {
escaping = true;
} else if (c == ',') {
kv_overrides.push_back(current);
current.clear();
} else {
current.push_back(c);
}
}
if (escaping) {
current.push_back('\\');
}
kv_overrides.push_back(current);
for (const auto & kv_override : kv_overrides) {
if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
}
}
}
));
@ -2237,33 +2278,50 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--lora"}, "FNAME",
"path to LoRA adapter (can be repeated to use multiple adapters)",
"path to LoRA adapter (use comma-separated values to load multiple adapters)",
[](common_params & params, const std::string & value) {
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
for (const auto & item : string_split<std::string>(value, ',')) {
params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
}
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
{"--lora-scaled"}, "FNAME:SCALE,...",
"path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
"note: use comma-separated values",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
}
params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr });
}
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--control-vector"}, "FNAME",
"add a control vector\nnote: this argument can be repeated to add multiple control vectors",
"add a control vector\nnote: use comma-separated values to add multiple control vectors",
[](common_params & params, const std::string & value) {
params.control_vectors.push_back({ 1.0f, value, });
for (const auto & item : string_split<std::string>(value, ',')) {
params.control_vectors.push_back({ 1.0f, item, });
}
}
));
add_opt(common_arg(
{"--control-vector-scaled"}, "FNAME", "SCALE",
{"--control-vector-scaled"}, "FNAME:SCALE,...",
"add a control vector with user defined scaling SCALE\n"
"note: this argument can be repeated to add multiple scaled control vectors",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.control_vectors.push_back({ std::stof(scale), fname });
"note: use comma-separated values (format: FNAME:SCALE,...)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
}
params.control_vectors.push_back({ std::stof(parts[1]), parts[0] });
}
}
));
add_opt(common_arg(
@ -2353,13 +2411,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("HF_TOKEN"));
add_opt(common_arg(
{"--context-file"}, "FNAME",
"file to load context from (repeat to specify multiple files)",
"file to load context from (use comma-separated values to specify multiple files)",
[](common_params & params, const std::string & value) {
std::ifstream file(value, std::ios::binary);
for (const auto & item : string_split<std::string>(value, ',')) {
std::ifstream file(item, std::ios::binary);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
}
params.context_files.push_back(item);
}
params.context_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(

View File

@ -1092,7 +1092,7 @@ common_init_result::common_init_result(common_params & params) :
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);

View File

@ -104,10 +104,9 @@ struct ring_buffer {
struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
bool grammar;
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
grammar = true;
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
samplers.push_back(
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()));
trigger_tokens.data(), trigger_tokens.size());
} else {
samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
grammar = true;
}
}
@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
@ -324,24 +320,11 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (gsmpl->grammar) {
const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
if (gsmpl->grmr && accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token);
}
gsmpl->prev.push_back(token);
}
@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx);
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id;
}
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
llama_sampler_apply(grmr, &single_token_data_array);
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

View File

@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample
//
@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

View File

@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0);
common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);

View File

@ -1,6 +1,26 @@
# Android
## Build with Android Studio
Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
![Project imported into Android Studio](./android/imported-into-android-studio.png)
This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
A minimal Android app frontend is included to showcase the bindings core functionalities:
1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`.
2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs.
3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`.
For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:
| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) |
|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:|
| Home screen | System prompt | "Haiku" |
## Build on Android using Termux
[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.

View File

@ -1,16 +1,18 @@
plugins {
id("com.android.application")
id("org.jetbrains.kotlin.android")
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.example.llama"
compileSdk = 34
compileSdk = 36
defaultConfig {
applicationId = "com.example.llama"
applicationId = "com.example.llama.aichat"
minSdk = 33
targetSdk = 34
targetSdk = 36
versionCode = 1
versionName = "1.0"
@ -21,8 +23,17 @@ android {
}
buildTypes {
debug {
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android.txt"),
"proguard-rules.pro"
)
}
release {
isMinifyEnabled = false
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
@ -36,30 +47,15 @@ android {
kotlinOptions {
jvmTarget = "1.8"
}
buildFeatures {
compose = true
}
composeOptions {
kotlinCompilerExtensionVersion = "1.5.1"
}
}
dependencies {
implementation(libs.bundles.androidx)
implementation(libs.material)
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
implementation("androidx.activity:activity-compose:1.8.2")
implementation(platform("androidx.compose:compose-bom:2023.08.00"))
implementation("androidx.compose.ui:ui")
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
implementation(project(":llama"))
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
androidTestImplementation("androidx.compose.ui:ui-test-junit4")
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")
implementation(project(":lib"))
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
}

View File

@ -19,3 +19,11 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-assumenosideeffects class android.util.Log {
public static int v(...);
public static int d(...);
}

View File

@ -1,24 +1,21 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:extractNativeLibs="true"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher"
android:icon="@mipmap/ic_launcher_round"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.LlamaAndroid"
android:theme="@style/Theme.AiChatSample"
>
<activity
android:name=".MainActivity"
android:exported="true"
android:theme="@style/Theme.LlamaAndroid">
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

View File

@ -1,119 +0,0 @@
package com.example.llama
import android.app.DownloadManager
import android.net.Uri
import android.util.Log
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableDoubleStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.core.database.getLongOrNull
import androidx.core.net.toUri
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import java.io.File
data class Downloadable(val name: String, val source: Uri, val destination: File) {
companion object {
@JvmStatic
private val tag: String? = this::class.qualifiedName
sealed interface State
data object Ready: State
data class Downloading(val id: Long): State
data class Downloaded(val downloadable: Downloadable): State
data class Error(val message: String): State
@JvmStatic
@Composable
fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
var status: State by remember {
mutableStateOf(
if (item.destination.exists()) Downloaded(item)
else Ready
)
}
var progress by remember { mutableDoubleStateOf(0.0) }
val coroutineScope = rememberCoroutineScope()
suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
while (true) {
val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
if (cursor == null) {
Log.e(tag, "dm.query() returned null")
return Error("dm.query() returned null")
}
if (!cursor.moveToFirst() || cursor.count < 1) {
cursor.close()
Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
return Ready
}
val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
val sofar = cursor.getLongOrNull(pix) ?: 0
val total = cursor.getLongOrNull(tix) ?: 1
cursor.close()
if (sofar == total) {
return Downloaded(item)
}
progress = (sofar * 1.0) / total
delay(1000L)
}
}
fun onClick() {
when (val s = status) {
is Downloaded -> {
viewModel.load(item.destination.path)
}
is Downloading -> {
coroutineScope.launch {
status = waitForDownload(s, item)
}
}
else -> {
item.destination.delete()
val request = DownloadManager.Request(item.source).apply {
setTitle("Downloading model")
setDescription("Downloading model: ${item.name}")
setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
setDestinationUri(item.destination.toUri())
}
viewModel.log("Saving ${item.name} to ${item.destination.path}")
Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
val id = dm.enqueue(request)
status = Downloading(id)
onClick()
}
}
}
Button(onClick = { onClick() }, enabled = status !is Downloading) {
when (status) {
is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
is Downloaded -> Text("Load ${item.name}")
is Ready -> Text("Download ${item.name}")
is Error -> Text("Download ${item.name}")
}
}
}
}
}

View File

@ -1,154 +1,257 @@
package com.example.llama
import android.app.ActivityManager
import android.app.DownloadManager
import android.content.ClipData
import android.content.ClipboardManager
import android.net.Uri
import android.os.Bundle
import android.os.StrictMode
import android.os.StrictMode.VmPolicy
import android.text.format.Formatter
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.viewModels
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.material3.Button
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.core.content.getSystemService
import com.example.llama.ui.theme.LlamaAndroidTheme
import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import com.arm.aichat.AiChat
import com.arm.aichat.InferenceEngine
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.util.UUID
class MainActivity(
activityManager: ActivityManager? = null,
downloadManager: DownloadManager? = null,
clipboardManager: ClipboardManager? = null,
): ComponentActivity() {
private val tag: String? = this::class.simpleName
class MainActivity : AppCompatActivity() {
private val activityManager by lazy { activityManager ?: getSystemService<ActivityManager>()!! }
private val downloadManager by lazy { downloadManager ?: getSystemService<DownloadManager>()!! }
private val clipboardManager by lazy { clipboardManager ?: getSystemService<ClipboardManager>()!! }
// Android views
private lateinit var ggufTv: TextView
private lateinit var messagesRv: RecyclerView
private lateinit var userInputEt: EditText
private lateinit var userActionFab: FloatingActionButton
private val viewModel: MainViewModel by viewModels()
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
// Get a MemoryInfo object for the device's current memory status.
private fun availableMemory(): ActivityManager.MemoryInfo {
return ActivityManager.MemoryInfo().also { memoryInfo ->
activityManager.getMemoryInfo(memoryInfo)
}
}
// Conversation states
private var isModelReady = false
private val messages = mutableListOf<Message>()
private val lastAssistantMsg = StringBuilder()
private val messageAdapter = MessageAdapter(messages)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
StrictMode.setVmPolicy(
VmPolicy.Builder(StrictMode.getVmPolicy())
.detectLeakedClosableObjects()
.build()
)
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
messagesRv.layoutManager = LinearLayoutManager(this)
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
val free = Formatter.formatFileSize(this, availableMemory().availMem)
val total = Formatter.formatFileSize(this, availableMemory().totalMem)
viewModel.log("Current memory: $free / $total")
viewModel.log("Downloads directory: ${getExternalFilesDir(null)}")
val extFilesDir = getExternalFilesDir(null)
val models = listOf(
Downloadable(
"Phi-2 7B (Q4_0, 1.6 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
File(extFilesDir, "phi-2-q4_0.gguf"),
),
Downloadable(
"TinyLlama 1.1B (f16, 2.2 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
File(extFilesDir, "tinyllama-1.1-f16.gguf"),
),
Downloadable(
"Phi 2 DPO (Q3_K_M, 1.48 GiB)",
Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
),
)
setContent {
LlamaAndroidTheme {
// A surface container using the 'background' color from the theme
Surface(
modifier = Modifier.fillMaxSize(),
color = MaterialTheme.colorScheme.background
) {
MainCompose(
viewModel,
clipboardManager,
downloadManager,
models,
)
// Arm AI Chat initialization
lifecycleScope.launch(Dispatchers.Default) {
engine = AiChat.getInferenceEngine(applicationContext)
}
// Upon CTA button tapped
userActionFab.setOnClickListener {
if (isModelReady) {
// If model is ready, validate input and send to engine
handleUserInput()
} else {
// Otherwise, prompt user to select a GGUF metadata on the device
getContent.launch(arrayOf("*/*"))
}
}
}
private val getContent = registerForActivityResult(
ActivityResultContracts.OpenDocument()
) { uri ->
Log.i(TAG, "Selected file uri:\n $uri")
uri?.let { handleSelectedModel(it) }
}
/**
* Handles the file Uri from [getContent] result
*/
private fun handleSelectedModel(uri: Uri) {
// Update UI states
userActionFab.isEnabled = false
userInputEt.hint = "Parsing GGUF..."
ggufTv.text = "Parsing metadata from selected file \n$uri"
lifecycleScope.launch(Dispatchers.IO) {
// Parse GGUF metadata
Log.i(TAG, "Parsing GGUF metadata...")
contentResolver.openInputStream(uri)?.use {
GgufMetadataReader.create().readStructuredMetadata(it)
}?.let { metadata ->
// Update UI to show GGUF metadata to user
Log.i(TAG, "GGUF parsed: \n$metadata")
withContext(Dispatchers.Main) {
ggufTv.text = metadata.toString()
}
// Ensure the model file is available
val modelName = metadata.filename() + FILE_EXTENSION_GGUF
contentResolver.openInputStream(uri)?.use { input ->
ensureModelFile(modelName, input)
}?.let { modelFile ->
loadModel(modelName, modelFile)
withContext(Dispatchers.Main) {
isModelReady = true
userInputEt.hint = "Type and send a message!"
userInputEt.isEnabled = true
userActionFab.setImageResource(R.drawable.outline_send_24)
userActionFab.isEnabled = true
}
}
}
}
}
@Composable
fun MainCompose(
viewModel: MainViewModel,
clipboard: ClipboardManager,
dm: DownloadManager,
models: List<Downloadable>
) {
Column {
val scrollState = rememberLazyListState()
Box(modifier = Modifier.weight(1f)) {
LazyColumn(state = scrollState) {
items(viewModel.messages) {
Text(
it,
style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
modifier = Modifier.padding(16.dp)
)
}
}
}
OutlinedTextField(
value = viewModel.message,
onValueChange = { viewModel.updateMessage(it) },
label = { Text("Message") },
)
Row {
Button({ viewModel.send() }) { Text("Send") }
Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
Button({ viewModel.clear() }) { Text("Clear") }
Button({
viewModel.messages.joinToString("\n").let {
clipboard.setPrimaryClip(ClipData.newPlainText("", it))
}
}) { Text("Copy") }
/**
* Prepare the model file within app's private storage
*/
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
withContext(Dispatchers.IO) {
File(ensureModelsDirectory(), modelName).also { file ->
// Copy the file into local storage if not yet done
if (!file.exists()) {
Log.i(TAG, "Start copying file to $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Copying file..."
}
Column {
for (model in models) {
Downloadable.Button(viewModel, dm, model)
FileOutputStream(file).use { input.copyTo(it) }
Log.i(TAG, "Finished copying file to $modelName")
} else {
Log.i(TAG, "File already exists $modelName")
}
}
}
/**
* Load the model file from the app private storage
*/
private suspend fun loadModel(modelName: String, modelFile: File) =
withContext(Dispatchers.IO) {
Log.i(TAG, "Loading model $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Loading model..."
}
engine.loadModel(modelFile.path)
}
/**
* Validate and send the user message into [InferenceEngine]
*/
private fun handleUserInput() {
userInputEt.text.toString().also { userSsg ->
if (userSsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
userActionFab.isEnabled = false
// Update message states
messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
lifecycleScope.launch(Dispatchers.Default) {
engine.sendUserPrompt(userSsg)
.onCompletion {
withContext(Dispatchers.Main) {
userActionFab.isEnabled = true
}
}.collect { token ->
val messageCount = messages.size
check(messageCount > 0 && !messages[messageCount - 1].isUser)
messages.removeAt(messageCount - 1).copy(
content = lastAssistantMsg.append(token).toString()
).let { messages.add(it) }
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
}
}
}
/**
* Run a benchmark with the model file
*/
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Running benchmark..."
}
engine.bench(
pp=BENCH_PROMPT_PROCESSING_TOKENS,
tg=BENCH_TOKEN_GENERATION_TOKENS,
pl=BENCH_SEQUENCE,
nr=BENCH_REPETITION
).let { result ->
messages.add(Message(UUID.randomUUID().toString(), result, false))
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
/**
* Create the `models` directory if not exist.
*/
private fun ensureModelsDirectory() =
File(filesDir, DIRECTORY_MODELS).also {
if (it.exists() && !it.isDirectory) { it.delete() }
if (!it.exists()) { it.mkdir() }
}
companion object {
private val TAG = MainActivity::class.java.simpleName
private const val DIRECTORY_MODELS = "models"
private const val FILE_EXTENSION_GGUF = ".gguf"
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
private const val BENCH_SEQUENCE = 1
private const val BENCH_REPETITION = 3
}
}
fun GgufMetadata.filename() = when {
basic.name != null -> {
basic.name?.let { name ->
basic.sizeLabel?.let { size ->
"$name-$size"
} ?: name
}
}
architecture?.architecture != null -> {
architecture?.architecture?.let { arch ->
basic.uuid?.let { uuid ->
"$arch-$uuid"
} ?: "$arch-${System.currentTimeMillis()}"
}
}
else -> {
"model-${System.currentTimeMillis().toHexString()}"
}
}

View File

@ -1,105 +0,0 @@
package com.example.llama
import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
}
private val tag: String? = this::class.simpleName
var messages by mutableStateOf(listOf("Initializing..."))
private set
var message by mutableStateOf("")
private set
override fun onCleared() {
super.onCleared()
viewModelScope.launch {
try {
llamaAndroid.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
}
}
fun send() {
val text = message
message = ""
// Add to messages console.
messages += text
messages += ""
viewModelScope.launch {
llamaAndroid.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
}
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
}
}
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
viewModelScope.launch {
try {
val start = System.nanoTime()
val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
val warmup = (end - start).toDouble() / NanosPerSecond
messages += "Warm up time: $warmup seconds, please wait..."
if (warmup > 5.0) {
messages += "Warm up took too long, aborting benchmark"
return@launch
}
messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
}
}
}
fun load(pathToModel: String) {
viewModelScope.launch {
try {
llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
messages += exc.message!!
}
}
}
fun updateMessage(newMessage: String) {
message = newMessage
}
fun clear() {
messages = listOf()
}
fun log(message: String) {
messages += message
}
}

View File

@ -0,0 +1,51 @@
package com.example.llama
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import android.widget.TextView
import androidx.recyclerview.widget.RecyclerView
data class Message(
val id: String,
val content: String,
val isUser: Boolean
)
class MessageAdapter(
private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
companion object {
private const val VIEW_TYPE_USER = 1
private const val VIEW_TYPE_ASSISTANT = 2
}
override fun getItemViewType(position: Int): Int {
return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
val layoutInflater = LayoutInflater.from(parent.context)
return if (viewType == VIEW_TYPE_USER) {
val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
UserMessageViewHolder(view)
} else {
val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
AssistantMessageViewHolder(view)
}
}
override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
val message = messages[position]
if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
textView.text = message.content
}
}
override fun getItemCount(): Int = messages.size
class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
}

View File

@ -1,11 +0,0 @@
package com.example.llama.ui.theme
import androidx.compose.ui.graphics.Color
val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)
val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)

View File

@ -1,70 +0,0 @@
package com.example.llama.ui.theme
import android.app.Activity
import android.os.Build
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat
private val DarkColorScheme = darkColorScheme(
primary = Purple80,
secondary = PurpleGrey80,
tertiary = Pink80
)
private val LightColorScheme = lightColorScheme(
primary = Purple40,
secondary = PurpleGrey40,
tertiary = Pink40
/* Other default colors to override
background = Color(0xFFFFFBFE),
surface = Color(0xFFFFFBFE),
onPrimary = Color.White,
onSecondary = Color.White,
onTertiary = Color.White,
onBackground = Color(0xFF1C1B1F),
onSurface = Color(0xFF1C1B1F),
*/
)
@Composable
fun LlamaAndroidTheme(
darkTheme: Boolean = isSystemInDarkTheme(),
// Dynamic color is available on Android 12+
dynamicColor: Boolean = true,
content: @Composable () -> Unit
) {
val colorScheme = when {
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
val context = LocalContext.current
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
}
darkTheme -> DarkColorScheme
else -> LightColorScheme
}
val view = LocalView.current
if (!view.isInEditMode) {
SideEffect {
val window = (view.context as Activity).window
window.statusBarColor = colorScheme.primary.toArgb()
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
}
}
MaterialTheme(
colorScheme = colorScheme,
typography = Typography,
content = content
)
}

View File

@ -1,34 +0,0 @@
package com.example.llama.ui.theme
import androidx.compose.material3.Typography
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp
// Set of Material typography styles to start with
val Typography = Typography(
bodyLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 16.sp,
lineHeight = 24.sp,
letterSpacing = 0.5.sp
)
/* Other default text styles to override
titleLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 22.sp,
lineHeight = 28.sp,
letterSpacing = 0.sp
),
labelSmall = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Medium,
fontSize = 11.sp,
lineHeight = 16.sp,
letterSpacing = 0.5.sp
)
*/
)

View File

@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#E5E5EA" />
<corners android:radius="16dp" />
</shape>

View File

@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#4285F4" />
<corners android:radius="16dp" />
</shape>

View File

@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
</vector>

View File

@ -0,0 +1,11 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal"
android:autoMirrored="true">
<path
android:fillColor="@android:color/white"
android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
</vector>

View File

@ -0,0 +1,76 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/main"
android:layout_height="match_parent"
android:layout_width="match_parent">
<LinearLayout
android:fitsSystemWindows="true"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<FrameLayout
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1">
<ScrollView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:fadeScrollbars="false">
<TextView
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_margin="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2"
android:maxLines="100" />
</ScrollView>
</FrameLayout>
<androidx.recyclerview.widget.RecyclerView
android:id="@+id/messages"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="4"
android:padding="16dp"
android:fadeScrollbars="false"
app:reverseLayout="true"
tools:listitem="@layout/item_message_assistant"/>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<EditText
android:id="@+id/user_input"
android:enabled="false"
android:layout_width="0dp"
android:layout_weight="1"
android:layout_height="match_parent"
android:padding="8dp"
style="@style/TextAppearance.MaterialComponents.Body2"
android:hint="Please first pick a GGUF model file to import." />
<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/fab"
android:enabled="true"
style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_margin="8dp"
android:src="@drawable/outline_folder_open_24" />
</LinearLayout>
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="start">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_assistant_message"
android:padding="12dp"
android:textColor="@android:color/black" />
</LinearLayout>

View File

@ -0,0 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="end">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_user_message"
android:padding="12dp"
android:textColor="@android:color/white" />
</LinearLayout>

View File

@ -1,3 +1,3 @@
<resources>
<string name="app_name">LlamaAndroid</string>
<string name="app_name">AI Chat basic sample</string>
</resources>

View File

@ -1,5 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<style name="Theme.LlamaAndroid" parent="android:Theme.Material.Light.NoActionBar" />
<style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
<!-- Customize your light theme here. -->
<!-- <item name="colorPrimary">@color/my_light_primary</item> -->
</style>
<style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
</resources>

View File

@ -1,6 +1,6 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false
id("com.android.library") version "8.2.0" apply false
alias(libs.plugins.android.application) apply false
alias(libs.plugins.android.library) apply false
alias(libs.plugins.jetbrains.kotlin.android) apply false
}

View File

@ -21,3 +21,4 @@ kotlin.code.style=official
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
android.native.buildOutput=verbose

View File

@ -0,0 +1,53 @@
[versions]
# Plugins
agp = "8.13.0"
kotlin = "2.2.20"
# AndroidX
activity = "1.11.0"
appcompat = "1.7.1"
core-ktx = "1.17.0"
constraint-layout = "2.2.1"
datastore-preferences = "1.1.7"
# Material
material = "1.13.0"
# Testing
espresso-core = "3.7.0"
androidx-junit = "1.3.0"
junit = "4.13.2"
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
[libraries]
# AndroidX
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
#Material
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
# Testing
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
junit = { group = "junit", name = "junit", version.ref = "junit" }
[bundles]
androidx = [
"androidx-activity",
"androidx-appcompat",
"androidx-constraintlayout",
"androidx-core-ktx",
"androidx-datastore-preferences",
]

View File

@ -1,6 +1,6 @@
#Thu Dec 21 14:31:09 AEDT 2023
#Tue Apr 01 11:15:06 PDT 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -0,0 +1,78 @@
plugins {
alias(libs.plugins.android.library)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.arm.aichat"
compileSdk = 36
ndkVersion = "29.0.13113456"
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
abiFilters += listOf("arm64-v8a", "x86_64")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
arguments += "-DBUILD_SHARED_LIBS=ON"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DGGML_NATIVE=OFF"
arguments += "-DGGML_BACKEND_DL=ON"
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
}
}
aarMetadata {
minCompileSdk = 35
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.31.6"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlin {
jvmToolchain(17)
compileOptions {
targetCompatibility = JavaVersion.VERSION_17
}
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
publishing {
singleVariant("release") {
withJavadocJar()
}
}
}
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
}

View File

@ -0,0 +1,8 @@
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-keepclasseswithmembernames class * {
native <methods>;
}
-keep class kotlin.Metadata { *; }

View File

@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.31.6)
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
# --------------------------------------------------------------------------
# AI Chat library
# --------------------------------------------------------------------------
if(DEFINED ANDROID_ABI)
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(GGML_SYSTEM_ARCH "ARM")
set(GGML_CPU_KLEIDIAI ON)
set(GGML_OPENMP ON)
elseif(ANDROID_ABI STREQUAL "x86_64")
set(GGML_SYSTEM_ARCH "x86")
set(GGML_CPU_KLEIDIAI OFF)
set(GGML_OPENMP OFF)
else()
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
endif()
endif()
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama)
add_library(${CMAKE_PROJECT_NAME} SHARED
ai_chat.cpp)
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
${LLAMA_SRC}
${LLAMA_SRC}/common
${LLAMA_SRC}/include
${LLAMA_SRC}/ggml/include
${LLAMA_SRC}/ggml/src)
target_link_libraries(${CMAKE_PROJECT_NAME}
llama
common
android
log)

View File

@ -0,0 +1,565 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <cmath>
#include <string>
#include <unistd.h>
#include <sampling.h>
#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
std::ostringstream str;
for (size_t i = 0; i < values.size(); i++) {
str << values[i];
if (i < values.size() - 1) { str << delim; }
}
return str.str();
}
/**
* LLama resources: context, model, batch and sampler
*/
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
// Set llama log handler to Android
llama_log_set(aichat_android_log_callback, nullptr);
// Loading all CPU backend variants
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
LOGi("Loading backends from %s", path_to_backend);
ggml_backend_load_all_from_path(path_to_backend);
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated; Log handler set.");
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
llama_model_params model_params = llama_model_default_params();
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
auto *model = llama_model_load_from_file(model_path, model_params);
env->ReleaseStringUTFChars(jmodel_path, model_path);
if (!model) {
return 1;
}
g_model = model;
return 0;
}
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
if (!model) {
LOGe("%s: model cannot be null", __func__);
return nullptr;
}
// Multi-threading setup
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
(int) sysconf(_SC_NPROCESSORS_ONLN) -
N_THREADS_HEADROOM));
LOGi("%s: Using %d threads", __func__, n_threads);
// Context parameters setup
llama_context_params ctx_params = llama_context_default_params();
const int trained_context_size = llama_model_n_ctx_train(model);
if (n_ctx > trained_context_size) {
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
__func__, trained_context_size, n_ctx);
}
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = BATCH_SIZE;
ctx_params.n_ubatch = BATCH_SIZE;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
auto *context = llama_init_from_model(g_model, ctx_params);
if (context == nullptr) {
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
}
return context;
}
static common_sampler *new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
return common_sampler_init(g_model, sparams);
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return 1; }
g_context = context;
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
g_chat_templates = common_chat_templates_init(g_model, "");
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
return 0;
}
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto *reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
return backends.empty() ? "CPU" : join(backends, ",");
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
jint pl, jint nr) {
auto *context = init_context(g_model, pp);
if (!context) {
const auto *const err_msg = "Fail to init_context! Bench aborted.";
LOGe(err_msg);
return env->NewStringUTF(err_msg);
}
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const uint32_t n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp = %d)", pp);
common_batch_clear(g_batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(g_batch, 0, i, {0}, false);
}
g_batch.logits[g_batch.n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg = %d)", tg);
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(g_batch);
for (j = 0; j < pl; j++) {
common_batch_add(g_batch, 0, i, {j}, true);
}
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
llama_free(context);
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(g_model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
const auto backend = get_backend();
std::stringstream result;
result << std::setprecision(3);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
/**
* Completion loop's long-term states:
* - chat management
* - position tracking
*/
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;
static void reset_long_term_states(const bool clear_kv_cache = true) {
chat_msgs.clear();
system_prompt_position = 0;
current_position = 0;
if (clear_kv_cache)
llama_memory_clear(llama_get_memory(g_context), false);
}
/**
* TODO-hyin: implement sliding-window version as a better alternative
*
* Context shifting by discarding the older half of the tokens appended after system prompt:
* - take the [system_prompt_position] first tokens from the original prompt
* - take half of the last (system_prompt_position - system_prompt_position) tokens
* - recompute the logits in batches
*/
static void shift_context() {
const int n_discard = (current_position - system_prompt_position) / 2;
LOGi("%s: Discarding %d tokens", __func__, n_discard);
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
current_position -= n_discard;
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
chat_msgs.push_back(new_msg);
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
return formatted;
}
/**
* Completion loop's short-term states:
* - stop generation position
* - token chars caching
* - current assistant message being generated
*/
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;
static void reset_short_term_states() {
stop_generation_position = 0;
cached_token_chars.clear();
assistant_ss.str("");
}
static int decode_tokens_in_batches(
llama_context *context,
llama_batch &batch,
const llama_tokens &tokens,
const llama_pos start_pos,
const bool compute_last_logit = false) {
// Process tokens in batches using the global batch
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(batch);
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
// Shift context if current batch cannot fit into the context
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
shift_context();
}
// Add tokens to the batch with proper positions
for (int j = 0; j < cur_batch_size; j++) {
const llama_token token_id = tokens[i + j];
const llama_pos position = start_pos + i + j;
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
common_batch_add(batch, token_id, position, {0}, want_logit);
}
// Decode this batch
const int decode_result = llama_decode(context, batch);
if (decode_result) {
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
return 1;
}
}
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring jsystem_prompt
) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
std::string formatted_system_prompt(system_prompt);
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
// Format system prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
has_chat_template, has_chat_template);
for (auto id: system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Handle context overflow
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) system_tokens.size() > max_batch_size) {
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
__func__, (int) system_tokens.size(), max_batch_size);
return 1;
}
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
system_prompt_position = current_position = (int) system_tokens.size();
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring juser_prompt,
jint n_predict
) {
// Reset short-term states
reset_short_term_states();
// Obtain and tokenize user prompt
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
std::string formatted_user_prompt(user_prompt);
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
// Format user prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
}
// Decode formatted user prompts
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id: user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int user_prompt_size = (int) user_tokens.size();
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if (user_prompt_size > max_batch_size) {
const int skipped_tokens = user_prompt_size - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
}
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
current_position += user_prompt_size;
stop_generation_position = current_position + user_prompt_size + n_predict;
return 0;
}
static bool is_valid_utf8(const char *string) {
if (!string) { return true; }
const auto *bytes = (const unsigned char *) string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
JNIEnv *env,
jobject /*unused*/
) {
// Infinite text generation via context shifting
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Context full! Shifting...", __func__);
shift_context();
}
// Stop if reaching the marked position
if (current_position >= stop_generation_position) {
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
return nullptr;
}
// Sample next token
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
common_sampler_accept(g_sampler, new_token_id, true);
// Populate the batch with new token, then decode
common_batch_clear(g_batch);
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
if (llama_decode(g_context, g_batch) != 0) {
LOGe("%s: llama_decode() failed for generated token", __func__);
return nullptr;
}
// Update position
current_position++;
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
return nullptr;
}
// If not EOG, convert to text
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
cached_token_chars += new_token_chars;
// Create and return a valid UTF-8 Java string
jstring result = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
result = env->NewStringUTF(cached_token_chars.c_str());
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
assistant_ss << cached_token_chars;
cached_token_chars.clear();
} else {
LOGv("id: %d,\tappend to cache", new_token_id);
result = env->NewStringUTF("");
}
return result;
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Free up resources
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
llama_backend_free();
}

View File

@ -0,0 +1,61 @@
//
// Created by Han Yin on 10/31/25.
//
#ifndef AICHAT_LOGGING_H
#define AICHAT_LOGGING_H
#endif //AICHAT_LOGGING_H
#pragma once
#include <android/log.h>
#ifndef LOG_TAG
#define LOG_TAG "ai-chat"
#endif
#ifndef LOG_MIN_LEVEL
#if defined(NDEBUG)
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
#else
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
#endif
#endif
static inline int ai_should_log(int prio) {
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
}
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGv(...) ((void)0)
#endif
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGd(...) ((void)0)
#endif
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
default: return ANDROID_LOG_DEFAULT;
}
}
static inline void aichat_android_log_callback(enum ggml_log_level level,
const char* text,
void* /*user*/) {
const int prio = android_log_prio_from_ggml(level);
if (!ai_should_log(prio)) return;
__android_log_write(prio, LOG_TAG, text);
}

View File

@ -0,0 +1,14 @@
package com.arm.aichat
import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl
/**
* Main entry point for Arm's AI Chat library.
*/
object AiChat {
/**
* Get the inference engine single instance.
*/
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}

View File

@ -0,0 +1,89 @@
package com.arm.aichat
import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
* Interface defining the core LLM inference operations.
*/
interface InferenceEngine {
/**
* Current state of the inference engine
*/
val state: StateFlow<State>
/**
* Load a model from the given path.
*
* @throws UnsupportedArchitectureException if model architecture not supported
*/
suspend fun loadModel(pathToModel: String)
/**
* Sends a system prompt to the loaded model
*/
suspend fun setSystemPrompt(systemPrompt: String)
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
/**
* Runs a benchmark with the specified parameters.
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
/**
* Unloads the currently loaded model.
*/
suspend fun cleanUp()
/**
* Cleans up resources when the engine is no longer needed.
*/
fun destroy()
/**
* States of the inference engine
*/
sealed class State {
object Uninitialized : State()
object Initializing : State()
object Initialized : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object Benchmarking : State()
object ProcessingSystemPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
data class Error(val exception: Exception) : State()
}
companion object {
const val DEFAULT_PREDICT_LENGTH = 1024
}
}
val State.isUninterruptible
get() = this is State.Initializing ||
this is State.LoadingModel ||
this is State.UnloadingModel ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt
val State.isModelLoaded: Boolean
get() = this is State.ModelReady ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt ||
this is State.Generating
class UnsupportedArchitectureException : Exception()

View File

@ -0,0 +1,61 @@
package com.arm.aichat.gguf
import kotlin.collections.get
/**
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
* The `label` matches what llamacli prints.
*/
enum class FileType(val code: Int, val label: String) {
ALL_F32(0, "all F32"),
MOSTLY_F16(1, "F16"),
MOSTLY_Q4_0(2, "Q4_0"),
MOSTLY_Q4_1(3, "Q4_1"),
// 4 removed
MOSTLY_Q8_0(7, "Q8_0"),
MOSTLY_Q5_0(8, "Q5_0"),
MOSTLY_Q5_1(9, "Q5_1"),
/* Kquants ------------------------------------------------------------ */
MOSTLY_Q2_K (10, "Q2_K - Medium"),
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
MOSTLY_Q6_K (18, "Q6_K"),
/* IQ quants ----------------------------------------------------------- */
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
/* BF16 & Ternary ------------------------------------------------------ */
MOSTLY_BF16 (32, "BF16"),
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
/* Special flag -------------------------------------------------------- */
GUESSED(1024, "(guessed)"),
UNKNOWN(-1, "unknown");
companion object {
private val map = entries.associateBy(FileType::code)
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
}
}

View File

@ -0,0 +1,132 @@
package com.arm.aichat.gguf
import java.io.IOException
/**
* Structured metadata of GGUF
*/
data class GgufMetadata(
// Basic file info
val version: GgufVersion,
val tensorCount: Long,
val kvCount: Long,
// General info
val basic: BasicInfo,
val author: AuthorInfo? = null,
val additional: AdditionalInfo? = null,
val architecture: ArchitectureInfo? = null,
val baseModels: List<BaseModelInfo>? = null,
val tokenizer: TokenizerInfo? = null,
// Derivative info
val dimensions: DimensionsInfo? = null,
val attention: AttentionInfo? = null,
val rope: RopeInfo? = null,
val experts: ExpertsInfo? = null
) {
enum class GgufVersion(val code: Int, val label: String) {
/** First public draft; littleendian only, no alignment key. */
LEGACY_V1(1, "Legacy v1"),
/** Added splitfile support and some extra metadata keys. */
EXTENDED_V2(2, "Extended v2"),
/** Current spec: endianaware, mandatory alignment, fully validated. */
VALIDATED_V3(3, "Validated v3");
companion object {
fun fromCode(code: Int): GgufVersion =
entries.firstOrNull { it.code == code }
?: throw IOException("Unknown GGUF version code $code")
}
override fun toString(): String = "$label (code=$code)"
}
data class BasicInfo(
val uuid: String? = null,
val name: String? = null,
val nameLabel: String? = null,
val sizeLabel: String? = null, // Size label like "7B"
)
data class AuthorInfo(
val organization: String? = null,
val author: String? = null,
val doi: String? = null,
val url: String? = null,
val repoUrl: String? = null,
val license: String? = null,
val licenseLink: String? = null,
)
data class AdditionalInfo(
val type: String? = null,
val description: String? = null,
val tags: List<String>? = null,
val languages: List<String>? = null,
)
data class ArchitectureInfo(
val architecture: String? = null,
val fileType: Int? = null,
val vocabSize: Int? = null,
val finetune: String? = null,
val quantizationVersion: Int? = null,
)
data class BaseModelInfo(
val name: String? = null,
val author: String? = null,
val version: String? = null,
val organization: String? = null,
val url: String? = null,
val doi: String? = null,
val uuid: String? = null,
val repoUrl: String? = null,
)
data class TokenizerInfo(
val model: String? = null,
val bosTokenId: Int? = null,
val eosTokenId: Int? = null,
val unknownTokenId: Int? = null,
val paddingTokenId: Int? = null,
val addBosToken: Boolean? = null,
val addEosToken: Boolean? = null,
val chatTemplate: String? = null,
)
data class DimensionsInfo(
val contextLength: Int? = null,
val embeddingSize: Int? = null,
val blockCount: Int? = null,
val feedForwardSize: Int? = null,
)
data class AttentionInfo(
val headCount: Int? = null,
val headCountKv: Int? = null,
val keyLength: Int? = null,
val valueLength: Int? = null,
val layerNormEpsilon: Float? = null,
val layerNormRmsEpsilon: Float? = null,
)
data class RopeInfo(
val frequencyBase: Float? = null,
val dimensionCount: Int? = null,
val scalingType: String? = null,
val scalingFactor: Float? = null,
val attnFactor: Float? = null,
val originalContextLength: Int? = null,
val finetuned: Boolean? = null,
)
data class ExpertsInfo(
val count: Int? = null,
val usedCount: Int? = null,
)
}

View File

@ -0,0 +1,77 @@
package com.arm.aichat.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Interface for reading GGUF metadata from model files.
* Use `GgufMetadataReader.create()` to get an instance.
*/
interface GgufMetadataReader {
/**
* Reads the magic number from the specified file path.
*
* @param file Java File to the GGUF file with absolute path
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(file: File): Boolean
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining [android.content.ContentProvider]
* @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
/**
* Reads and parses GGUF metadata from the specified file path.
*
* @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object {
private val DEFAULT_SKIP_KEYS = setOf(
"tokenizer.chat_template",
"tokenizer.ggml.scores",
"tokenizer.ggml.tokens",
"tokenizer.ggml.token_type"
)
/**
* Creates a default GgufMetadataReader instance
*/
fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = DEFAULT_SKIP_KEYS,
arraySummariseThreshold = 1_000
)
/**
* Creates a GgufMetadataReader with custom configuration
*
* @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
* @param arraySummariseThreshold If 0, arrays longer get summarised, not materialised;
* If -1, never summarise.
*/
fun create(
skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
arraySummariseThreshold: Int = 1_000
): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = skipKeys,
arraySummariseThreshold = arraySummariseThreshold
)
}
}
class InvalidFileFormatException : IOException()

View File

@ -0,0 +1,309 @@
package com.arm.aichat.internal
import android.content.Context
import android.util.Log
import com.arm.aichat.InferenceEngine
import com.arm.aichat.UnsupportedArchitectureException
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
import dalvik.annotation.optimization.FastNative
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException
/**
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
*
* This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
* All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
* with the underlying C++ native code.
*
* The typical usage flow is:
* 1. Get instance via [getInstance]
* 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams
* 5. Perform [cleanUp] when done with a model
* 6. Properly [destroy] when completely done
*
* State transitions are managed automatically and validated at each operation.
*
* @see ai_chat.cpp for the native implementation details
*/
internal class InferenceEngineImpl private constructor(
private val nativeLibDir: String
) : InferenceEngine {
companion object {
private val TAG = InferenceEngineImpl::class.java.simpleName
@Volatile
private var instance: InferenceEngine? = null
/**
* Create or obtain [InferenceEngineImpl]'s single instance.
*
* @param Context for obtaining native library directory
* @throws IllegalArgumentException if native library path is invalid
* @throws UnsatisfiedLinkError if library failed to load
*/
internal fun getInstance(context: Context) =
instance ?: synchronized(this) {
val nativeLibDir = context.applicationInfo.nativeLibraryDir
require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
try {
Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
InferenceEngineImpl(nativeLibDir).also { instance = it }
} catch (e: UnsatisfiedLinkError) {
Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
throw e
}
}
}
/**
* JNI methods
* @see ai_chat.cpp
*/
@FastNative
private external fun init(nativeLibDir: String)
@FastNative
private external fun load(modelPath: String): Int
@FastNative
private external fun prepare(): Int
@FastNative
private external fun systemInfo(): String
@FastNative
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
@FastNative
private external fun processSystemPrompt(systemPrompt: String): Int
@FastNative
private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
@FastNative
private external fun generateNextToken(): String?
@FastNative
private external fun unload()
@FastNative
private external fun shutdown()
private val _state =
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
override val state: StateFlow<InferenceEngine.State> = _state
private var _readyForSystemPrompt = false
/**
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
*/
@OptIn(ExperimentalCoroutinesApi::class)
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
init {
llamaScope.launch {
try {
check(_state.value is InferenceEngine.State.Uninitialized) {
"Cannot load native library in ${_state.value.javaClass.simpleName}!"
}
_state.value = InferenceEngine.State.Initializing
Log.i(TAG, "Loading native library...")
System.loadLibrary("ai-chat")
init(nativeLibDir)
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
} catch (e: Exception) {
Log.e(TAG, "Failed to load native library", e)
throw e
}
}
}
/**
* Load the LLM
*/
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.Initialized) {
"Cannot load model in ${_state.value.javaClass.simpleName}!"
}
try {
Log.i(TAG, "Checking access to model file... \n$pathToModel")
File(pathToModel).let {
require(it.exists()) { "File not found" }
require(it.isFile) { "Not a valid file" }
require(it.canRead()) { "Cannot read file" }
}
Log.i(TAG, "Loading model... \n$pathToModel")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.LoadingModel
load(pathToModel).let {
// TODO-han.yin: find a better way to pass other error codes
if (it != 0) throw UnsupportedArchitectureException()
}
prepare().let {
if (it != 0) throw IOException("Failed to prepare resources")
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}
/**
* Process the plain text system prompt
*
* TODO-han.yin: return error code if system prompt not correct processed?
*/
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
}
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingSystemPrompt
processSystemPrompt(prompt).let { result ->
if (result != 0) {
RuntimeException("Failed to process system prompt: $result").also {
_state.value = InferenceEngine.State.Error(it)
throw it
}
}
}
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
}
/**
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
*/
override fun sendUserPrompt(
message: String,
predictLength: Int,
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
_state.value = InferenceEngine.State.Generating
while (true) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
} catch (e: CancellationException) {
Log.i(TAG, "Generation cancelled by user.")
_state.value = InferenceEngine.State.ModelReady
throw e
} catch (e: Exception) {
Log.e(TAG, "Error during generation!", e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}.flowOn(llamaDispatcher)
/**
* Benchmark the model
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.ModelReady) {
"Benchmark request discarded due to: $state"
}
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
_readyForSystemPrompt = false // Just to be safe
_state.value = InferenceEngine.State.Benchmarking
benchModel(pp, tg, pl, nr).also {
_state.value = InferenceEngine.State.ModelReady
}
}
/**
* Unloads the model and frees resources, or reset error states
*/
override suspend fun cleanUp() =
withContext(llamaDispatcher) {
when (val state = _state.value) {
is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.UnloadingModel
unload()
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Model unloaded!")
Unit
}
is InferenceEngine.State.Error -> {
Log.i(TAG, "Resetting error states...")
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "States reset!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}
/**
* Cancel all ongoing coroutines and free GGML backends
*/
override fun destroy() {
_readyForSystemPrompt = false
llamaScope.cancel()
when(_state.value) {
is InferenceEngine.State.Uninitialized -> {}
is InferenceEngine.State.Initialized -> shutdown()
else -> { unload(); shutdown() }
}
}
}

View File

@ -0,0 +1,590 @@
package com.arm.aichat.internal.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.arm.aichat.gguf.InvalidFileFormatException
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Utility class to read GGUF model files and extract metadata key-value pairs.
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
*/
internal class GgufMetadataReaderImpl(
private val skipKeys: Set<String>,
private val arraySummariseThreshold: Int,
) : GgufMetadataReader {
companion object {
private const val ARCH_LLAMA = "llama"
}
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
enum class MetadataType(val code: Int) {
UINT8(0), INT8(1), UINT16(2), INT16(3),
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
companion object {
private val codeMap = entries.associateBy(MetadataType::code)
fun fromCode(code: Int): MetadataType = codeMap[code]
?: throw IOException("Unknown metadata value type code: $code")
}
}
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
sealed class MetadataValue {
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
}
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
private fun MetadataValue.toPrimitive(): Any = when (this) {
is MetadataValue.UInt8 -> value
is MetadataValue.Int8 -> value
is MetadataValue.UInt16 -> value
is MetadataValue.Int16 -> value
is MetadataValue.UInt32 -> value
is MetadataValue.Int32 -> value
is MetadataValue.Float32 -> value
is MetadataValue.Bool -> value
is MetadataValue.StringVal -> value
is MetadataValue.UInt64 -> value
is MetadataValue.Int64 -> value
is MetadataValue.Float64 -> value
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
}
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(file: File): Boolean =
file.inputStream().buffered().use { ensureMagic(it) }
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean =
ByteArray(4).let {
if (input.read(it) != 4) throw IOException("Not a valid file!")
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
}
/**
* Highlevel entry point: parses a `.gguf` file on disk and returns the fully
* populated [GgufMetadata] tree.
*
* Steps performed internally:
* 1. Reads and validates the 8byte header (`"GGUF"` magic + version).
* 2. Streams through the keyvalue section, skipping large blobs if the key
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
* 3. Converts the resulting raw map into stronglytyped substructures
* (basic info, tokenizer, rope, etc.).
*
* The method is STREAMINGONLY: tensors are never mapped or loaded into
* memory, so even multiGB model files can be processed in < 50 ms.
*
* @param path Absolute or relative filesystem path to a `.gguf` file.
* @return A [GgufMetadata] instance containing all recognised metadata plus
* an `allMetadata` map with any keys that were not given a dedicated
* field.
* @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt.
*/
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
// ── 1. header ──────────────────────────────────────────────────────────
// throws on mismatch
val version = ensureMagicAndVersion(input)
val tensorCount = readLittleLong(input)
val kvCount = readLittleLong(input)
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
// ── 3. build structured object ────────────────────────────────────────
return buildStructured(meta, version, tensorCount, kvCount)
}
/** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
if (!ensureMagic(input)) throw InvalidFileFormatException()
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
}
/**
* Read an unsigned 32bit littleendian integer.
*
* @throws IOException if fewer than four bytes are available.
*/
private fun readLEUInt32(input: InputStream): Int {
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
return (b3 and 0xFF shl 24) or
(b2 and 0xFF shl 16) or
(b1 and 0xFF shl 8) or
(b0 and 0xFF)
}
/**
* Lowlevel helper that reads the entire key-value section from the current
* stream position.
*
* @param input Open stream positioned JUST AFTER the header.
* @param kvCnt Number of keyvalue pairs (taken from the header).
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
*
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly.
*/
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
mutableMapOf<String, MetadataValue>().apply {
repeat(kvCnt.toInt()) {
val key = readString(input)
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
if (key in skipKeys) {
skipValue(input, valueT)
} else {
this[key] = parseValue(input, valueT)
}
}
}
/**
* Converts a flat [Map]<[String], [MetadataValue]> into the stronglytyped
* [GgufMetadata] tree used by the rest of the app.
*
* Only the keys listed in the spec are copied into dedicated data classes;
* everything else is preserved in `GgufMetadata.allMetadata`.
*
* @param m Raw key/value map.
* @param version GGUF fileformat version (enum).
* @param tensorCnt Number of tensors (from the header).
* @param kvCnt Total metadata pair count (from the header).
*/
private fun buildStructured(
m: Map<String, MetadataValue>,
version: GgufMetadata.GgufVersion,
tensorCnt: Long,
kvCnt: Long
): GgufMetadata {
// ---------- helpers ----------
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
fun String.strList(): List<String>? =
(m[this] as? MetadataValue.ArrayVal)
?.elements
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
val arch = "general.architecture".str() ?: ARCH_LLAMA
// -------------- populate sections ----------------
val basic = GgufMetadata.BasicInfo(
uuid = "general.uuid".str(),
name = "general.basename".str(),
nameLabel = "general.name".str(),
sizeLabel = "general.size_label".str()
)
val author = GgufMetadata.AuthorInfo(
organization = "general.organization".str(),
author = "general.author".str(),
doi = "general.doi".str(),
url = "general.url".str(),
repoUrl = "general.repo_url".str(),
license = "general.license".str(),
licenseLink = "general.license.link".str()
).takeUnless {
organization == null && author == null && doi == null &&
url == null && repoUrl == null && license == null && licenseLink == null
}
val additional = GgufMetadata.AdditionalInfo(
type = "general.type".str(),
description = "general.description".str(),
tags = "general.tags".strList(),
languages = "general.languages".strList()
).takeUnless {
type == null && description == null && tags == null && languages == null
}
val architectureInfo = GgufMetadata.ArchitectureInfo(
architecture = arch,
fileType = "general.file_type".u32(),
vocabSize = "$arch.vocab_size".u32(),
finetune = "general.finetune".str(),
quantizationVersion = "general.quantization_version".u32()
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
val baseModels = buildList {
val n = "general.base_model.count".u32() ?: 0
for (i in 0 until n) {
fun k(s: String) = "general.base_model.$i.$s"
add(
GgufMetadata.BaseModelInfo(
name = k("name").str(),
author = k("author").str(),
version = k("version").str(),
organization = k("organization").str(),
url = k("url").str(),
doi = k("doi").str(),
uuid = k("uuid").str(),
repoUrl = k("repo_url").str(),
)
)
}
}.takeIf { it.isNotEmpty() }
val tokenizer = GgufMetadata.TokenizerInfo(
model = "tokenizer.ggml.model".str(),
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
chatTemplate = "tokenizer.chat_template".str()
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
unknownTokenId == null && paddingTokenId == null &&
addBosToken == null && addEosToken == null && chatTemplate == null
}
val dimensions = GgufMetadata.DimensionsInfo(
contextLength = "$arch.context_length".u32(),
embeddingSize = "$arch.embedding_length".u32(),
blockCount = "$arch.block_count".u32(),
feedForwardSize = "$arch.feed_forward_length".u32()
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
val attention = GgufMetadata.AttentionInfo(
headCount = "$arch.attention.head_count".u32(),
headCountKv = "$arch.attention.head_count_kv".u32(),
keyLength = "$arch.attention.key_length".u32(),
valueLength = "$arch.attention.value_length".u32(),
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
layerNormEpsilon == null && layerNormRmsEpsilon == null
}
val rope = GgufMetadata.RopeInfo(
frequencyBase = "$arch.rope.freq_base".f32(),
dimensionCount = "$arch.rope.dimension_count".u32(),
scalingType = "$arch.rope.scaling.type".str(),
scalingFactor = "$arch.rope.scaling.factor".f32(),
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
finetuned = "$arch.rope.scaling.finetuned".bool()
).takeUnless { frequencyBase == null && dimensionCount == null &&
scalingType == null && scalingFactor == null && attnFactor == null &&
originalContextLength == null && finetuned == null
}
val experts = GgufMetadata.ExpertsInfo(
count = "$arch.expert_count".u32(),
usedCount = "$arch.expert_used_count".u32()
).takeUnless { count == null && usedCount == null }
return GgufMetadata(
version = version,
tensorCount = tensorCnt,
kvCount = kvCnt,
basic = basic,
author = author,
additional = additional,
architecture = architectureInfo,
baseModels = baseModels,
tokenizer = tokenizer,
dimensions = dimensions,
attention = attention,
rope = rope,
experts = experts
)
}
/**
* Recursively parses a metadata value of the given type from the input stream.
* @param input The input stream positioned at the start of the value.
* @param type The metadata value type to parse.
*/
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
MetadataType.UINT8 -> {
// 1-byte unsigned integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
MetadataValue.UInt8(byteVal.toUByte())
}
MetadataType.INT8 -> {
// 1-byte signed integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
MetadataValue.Int8(byteVal.toByte())
}
MetadataType.UINT16 -> {
// 2-byte unsigned integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
// Combine two bytes (little-endian) into an unsigned 16-bit value
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.UInt16(u16.toUShort())
}
MetadataType.INT16 -> {
// 2-byte signed integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
// Combine to 16-bit and interpret as signed
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.Int16(i16.toShort())
}
MetadataType.UINT32 -> {
// 4-byte unsigned integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.UInt32(u32.toUInt())
}
MetadataType.INT32 -> {
// 4-byte signed integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
// Combine four bytes into a 32-bit signed int
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
MetadataValue.Int32(i32)
}
MetadataType.FLOAT32 -> {
// 4-byte IEEE 754 float (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
val bits = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
val floatVal = Float.fromBits(bits)
MetadataValue.Float32(floatVal)
}
MetadataType.BOOL -> {
// 1-byte boolean (0 = false, 1 = true)
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
if (byteVal != 0 && byteVal != 1) {
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
}
MetadataValue.Bool(byteVal != 0)
}
MetadataType.STRING -> {
// UTF-8 string (length-prefixed with 8-byte length)
val str = readString(input)
MetadataValue.StringVal(str)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
val count = len.toInt()
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
// fastforward without allocation
repeat(count) { skipValue(input, elemType) }
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
} else {
val list = ArrayList<MetadataValue>(count)
repeat(count) { list += parseValue(input, elemType) }
MetadataValue.ArrayVal(elemType, list)
}
}
MetadataType.UINT64 -> {
// 8-byte unsigned integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
(bytes[6].toULong() and 0xFFuL shl 48) or
(bytes[5].toULong() and 0xFFuL shl 40) or
(bytes[4].toULong() and 0xFFuL shl 32) or
(bytes[3].toULong() and 0xFFuL shl 24) or
(bytes[2].toULong() and 0xFFuL shl 16) or
(bytes[1].toULong() and 0xFFuL shl 8) or
(bytes[0].toULong() and 0xFFuL)
MetadataValue.UInt64(u64)
}
MetadataType.INT64 -> {
// 8-byte signed integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
// Combine 8 bytes into a signed 64-bit value (Long)
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.Int64(i64)
}
MetadataType.FLOAT64 -> {
// 8-byte IEEE 754 double (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
val doubleVal = Double.fromBits(bits)
MetadataValue.Float64(doubleVal)
}
}
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
this?.takeIf { !it.check() }
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
private fun skipValue(input: InputStream, type: MetadataType) {
when (type) {
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
MetadataType.STRING -> {
val len = readLittleLong(input); input.skipFully(len)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
}
}
}
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
private fun readLittleLong(input: InputStream): Long {
val bytes = ByteArray(8)
input.readFully(bytes)
// Combine 8 bytes into a 64-bit value (Little Endian).
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
// In our context (lengths/counts), such extremely large values are not expected.
return (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
}
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
private fun readString(input: InputStream): String =
// Read 8-byte little-endian length (number of bytes in the string).
readLittleLong(input).let { len ->
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
// Read the UTF-8 bytes of the given length.
ByteArray(len.toInt()).let {
if (it.isNotEmpty()) input.readFully(it)
String(it, Charsets.UTF_8)
}
}
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
// Note: assumes bytes length is 4.
(bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
/**
* Robust skip that works the same on JDK 11 and Androids desugared runtime.
*
* @param n Number of bytes to advance in the stream.
* @throws IOException on premature EOF.
*/
private fun InputStream.skipFully(n: Long) {
var remaining = n
val scratch = ByteArray(8192) // readandtoss buffer
while (remaining > 0) {
val skipped = skip(remaining)
when {
skipped > 0 -> remaining -= skipped // normal fast path
skipped == 0L -> {
// fallback: read and discard
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
if (read == -1) throw IOException("EOF while skipping $n bytes")
remaining -= read
}
else -> throw IOException("Skip returned negative value")
}
}
}
/**
* Extension that keeps reading until the requested number of bytes are filled.
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
* streams.
*
* @param buf Destination buffer.
* @param len Number of bytes to fill (defaults to `buf.size`).
* @throws IOException on premature EOF.
*/
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
var off = 0
while (off < len) {
val n = read(buf, off, len - off)
if (n == -1) throw IOException("EOF after $off of $len bytes")
off += n
}
}
/**
* Read EXACTLY `n` bytes or throw never returns a partiallyfilled array.
* This is used for small fixedlength reads (e.g. 4byte type codes).
*
* @throws IOException on premature EOF.
*/
private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
if (read(it) != n) throw IOException("Unexpected EOF")
}
}

View File

@ -1,71 +0,0 @@
plugins {
id("com.android.library")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "android.llama.cpp"
compileSdk = 34
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
// Add NDK properties if wanted, e.g.
// abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
arguments += "-DCMAKE_BUILD_TYPE=Release"
cppFlags += listOf()
arguments += listOf()
cppFlags("")
}
}
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.22.1"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
}
dependencies {
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.appcompat:appcompat:1.6.1")
implementation("com.google.android.material:material:1.11.0")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}

View File

@ -1,53 +0,0 @@
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html.
# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)
# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("llama-android")
#include(FetchContent)
#FetchContent_Declare(
# llama
# GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
# GIT_TAG master
#)
# Also provides "common"
#FetchContent_MakeAvailable(llama)
# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.
#
# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
# is preferred for the same purpose.
#
#load local llama.cpp
add_subdirectory(../../../../../../ build-llama)
# In order to load a library into your app from Java/Kotlin, you must call
# System.loadLibrary() and pass the name of the library defined here;
# for GameActivity/NativeActivity derived applications, the same library name must be
# used in the AndroidManifest.xml file.
add_library(${CMAKE_PROJECT_NAME} SHARED
# List C/C++ source files with relative paths to this CMakeLists.txt.
llama-android.cpp)
# Specifies libraries CMake should link to your target library. You
# can link libraries from various origins, such as libraries defined in this
# build script, prebuilt third-party libraries, or Android system libraries.
target_link_libraries(${CMAKE_PROJECT_NAME}
# List libraries link to the target library
llama
common
android
log)

View File

@ -1,452 +0,0 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <math.h>
#include <string>
#include <unistd.h>
#include "llama.h"
#include "common.h"
// Write C++ code here.
//
// Do not forget to dynamically load the C++ library into your application.
//
// For instance,
//
// In MainActivity.java:
// static {
// System.loadLibrary("llama-android");
// }
//
// Or, in MainActivity.kt:
// companion object {
// init {
// System.loadLibrary("llama-android")
// }
// }
#define TAG "llama-android.cpp"
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;
std::string cached_token_chars;
bool is_valid_utf8(const char * string) {
if (!string) {
return true;
}
const unsigned char * bytes = (const unsigned char *)string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
llama_model_params model_params = llama_model_default_params();
auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model);
auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) {
LOGe("load_model() failed");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
return 0;
}
return reinterpret_cast<jlong>(model);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_model_free(reinterpret_cast<llama_model *>(model));
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
auto model = reinterpret_cast<llama_model *>(jmodel);
if (!model) {
LOGe("new_context(): model cannot be null");
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
return 0;
}
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
LOGi("Using %d threads", n_threads);
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
llama_context * context = llama_new_context_with_model(model, ctx_params);
if (!context) {
LOGe("llama_new_context_with_model() returned null)");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
"llama_new_context_with_model() returned null)");
return 0;
}
return reinterpret_cast<jlong>(context);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
llama_free(reinterpret_cast<llama_context *>(context));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
llama_backend_free();
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
llama_log_set(log_callback, NULL);
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_bench_1model(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong model_pointer,
jlong batch_pointer,
jint pp,
jint tg,
jint pl,
jint nr
) {
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto model = reinterpret_cast<llama_model *>(model_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const int n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp)");
common_batch_clear(*batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(*batch, 0, i, { 0 }, false);
}
batch->logits[batch->n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg)");
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(*batch);
for (j = 0; j < pl; j++) {
common_batch_add(*batch, 0, i, { j }, true);
}
LOGi("llama_decode() text generation: %d", i);
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
const auto backend = "(Android)"; // TODO: What should this be?
std::stringstream result;
result << std::setprecision(2);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
llama_batch *batch = new llama_batch {
0,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
};
if (embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
} else {
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return reinterpret_cast<jlong>(batch);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
delete batch;
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = true;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
return reinterpret_cast<jlong>(smpl);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
llama_backend_init();
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1init(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jstring jtext,
jboolean format_chat,
jint n_len
) {
cached_token_chars.clear();
const auto text = env->GetStringUTFChars(jtext, 0);
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + n_len;
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
if (n_kv_req > n_ctx) {
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
}
for (auto id : tokens_list) {
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}
common_batch_clear(*batch);
// evaluate the initial prompt
for (auto i = 0; i < tokens_list.size(); i++) {
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
}
// llama_decode will output logits only for the last token of the prompt
batch->logits[batch->n_tokens - 1] = true;
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");
}
env->ReleaseStringUTFChars(jtext, text);
return batch->n_tokens;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jlong sampler_pointer,
jint n_len,
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context);
const auto vocab = llama_model_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
// sample the most likely token
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr;
}
auto new_token_chars = common_token_to_piece(context, new_token_id);
cached_token_chars += new_token_chars;
jstring new_token = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
new_token = env->NewStringUTF(cached_token_chars.c_str());
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
cached_token_chars.clear();
} else {
new_token = env->NewStringUTF("");
}
common_batch_clear(*batch);
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() returned null");
}
return new_token;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
}

View File

@ -1,180 +0,0 @@
package android.llama.cpp
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread
class LLamaAndroid {
private val tag: String? = this::class.simpleName
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = "Llm-RunLoop") {
Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary("llama-android")
// Set llama log handler to Android
log_to_android()
backend_init(false)
Log.d(tag, system_info())
it.run()
}.apply {
uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
Log.e(tag, "Unhandled exception", exception)
}
}
}.asCoroutineDispatcher()
private val nlen: Int = 64
private external fun log_to_android()
private external fun load_model(filename: String): Long
private external fun free_model(model: Long)
private external fun new_context(model: Long): Long
private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean)
private external fun backend_free()
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
private external fun free_batch(batch: Long)
private external fun new_sampler(): Long
private external fun free_sampler(sampler: Long)
private external fun bench_model(
context: Long,
model: Long,
batch: Long,
pp: Int,
tg: Int,
pl: Int,
nr: Int
): String
private external fun system_info(): String
private external fun completion_init(
context: Long,
batch: Long,
text: String,
formatChat: Boolean,
nLen: Int
): Int
private external fun completion_loop(
context: Long,
batch: Long,
sampler: Long,
nLen: Int,
ncur: IntVar
): String?
private external fun kv_cache_clear(context: Long)
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
Log.d(tag, "bench(): $state")
bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
}
else -> throw IllegalStateException("No model loaded")
}
}
}
suspend fun load(pathToModel: String) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.Idle -> {
val model = load_model(pathToModel)
if (model == 0L) throw IllegalStateException("load_model() failed")
val context = new_context(model)
if (context == 0L) throw IllegalStateException("new_context() failed")
val batch = new_batch(512, 0, 1)
if (batch == 0L) throw IllegalStateException("new_batch() failed")
val sampler = new_sampler()
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
Log.i(tag, "Loaded model $pathToModel")
threadLocalState.set(State.Loaded(model, context, batch, sampler))
}
else -> throw IllegalStateException("Model already loaded")
}
}
}
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {
break
}
emit(str)
}
kv_cache_clear(state.context)
}
else -> {}
}
}.flowOn(runLoop)
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
*/
suspend fun unload() {
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
free_context(state.context)
free_model(state.model)
free_batch(state.batch)
free_sampler(state.sampler);
threadLocalState.set(State.Idle)
}
else -> {}
}
}
}
companion object {
private class IntVar(value: Int) {
@Volatile
var value: Int = value
private set
fun inc() {
synchronized(this) {
value += 1
}
}
}
private sealed interface State {
data object Idle: State
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
}
// Enforce only one instance of Llm.
private val _instance: LLamaAndroid = LLamaAndroid()
fun instance(): LLamaAndroid = _instance
}
}

View File

@ -8,11 +8,11 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
google()
}
}
rootProject.name = "LlamaAndroid"
rootProject.name = "AiChat"
include(":app")
include(":llama")
include(":lib")

View File

@ -5,7 +5,7 @@ import os
import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
import torch
import numpy as np
@ -116,11 +116,11 @@ def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
@ -130,6 +130,7 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
args = parser.parse_args()
model_path = os.environ.get("MODEL_PATH", args.model_path)
@ -142,8 +143,13 @@ if model_path is None:
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
multimodal = False
full_config = config
print("Model type: ", config.model_type)
if "vocab_size" not in config and "text_config" in config:
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
@ -168,6 +174,11 @@ if unreleased_model_name:
except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}")
exit(1)
else:
if multimodal:
model = AutoModelForImageTextToText.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
@ -185,7 +196,10 @@ model_name = os.path.basename(model_path)
print(f"Model class: {model.__class__.__name__}")
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
if args.prompt_file:
with open(args.prompt_file, encoding='utf-8') as f:
prompt = f.read()
elif os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
@ -195,9 +209,18 @@ print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
batch_size = 512
with torch.no_grad():
outputs = model(input_ids.to(model.device))
logits = outputs.logits
past = None
outputs = None
for i in range(0, input_ids.size(1), batch_size):
print(f"Processing chunk with tokens {i} to {i + batch_size}")
chunk = input_ids[:, i:i + batch_size]
outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
past = outputs.past_key_values
logits = outputs.logits # type: ignore
# Extract logits for the last token (next token prediction)
last_logits = logits[0, -1, :].float().cpu().numpy()

View File

@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sampling.temp > 0) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue;
}
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

View File

@ -386,6 +386,9 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
elseif (APPLE)
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)

View File

@ -43,6 +43,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@ -51,6 +53,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@ -67,10 +71,14 @@
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
// quants.c
@ -91,6 +99,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@ -99,6 +109,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__loongarch64)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@ -119,6 +131,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@ -127,6 +141,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__riscv)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@ -154,6 +170,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
@ -161,6 +179,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__s390x__)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@ -187,6 +207,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@ -195,6 +217,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__wasm__)
// quants.c
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@ -223,6 +247,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@ -231,4 +257,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#endif

View File

@ -786,6 +786,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 4;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) {
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
float32x4_t acc = vdupq_n_f32(0);
for (int b = 0; b < nb; b++) {
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
int32x4_t ret = vdupq_n_s32(0);
ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
a_ptr++;
b_ptr++;
}
vst1q_f32(s, acc);
s += ncols_interleaved;
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q8_0_4x8_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
for (int c = 0; c < nc; c += ncols_interleaved) {
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
float32x4_t acc = vdupq_n_f32(0);
for (int b = 0; b < nb; b++) {
int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
int32x4_t ret0 = vdupq_n_s32(0);
int32x4_t ret1 = vdupq_n_s32(0);
// 0..7
ret0 = vdotq_s32(ret0, b_low.val[0], a0);
ret1 = vdotq_s32(ret1, b_low.val[1], a0);
// 8..15
ret0 = vdotq_s32(ret0, b_low.val[2], a1);
ret1 = vdotq_s32(ret1, b_low.val[3], a1);
// 16..23
ret0 = vdotq_s32(ret0, b_high.val[0], a2);
ret1 = vdotq_s32(ret1, b_high.val[1], a2);
// 24..31
ret0 = vdotq_s32(ret0, b_high.val[2], a3);
ret1 = vdotq_s32(ret1, b_high.val[3], a3);
int32x4_t ret = vpaddq_s32(ret0, ret1);
acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
a_ptr++;
b_ptr++;
}
vst1q_f32(s, acc);
s += ncols_interleaved;
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -2610,3 +2737,159 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
float32x4_t sumf[4];
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}
for (int l = 0; l < nb; l++) {
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
int32x4_t sumi_0 = vdupq_n_s32(0);
int32x4_t sumi_1 = vdupq_n_s32(0);
int32x4_t sumi_2 = vdupq_n_s32(0);
int32x4_t sumi_3 = vdupq_n_s32(0);
for (int k_group = 0; k_group < 8; k_group += 4) {
int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
for (int k = 0; k < 4; k++) {
sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
}
}
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
}
for (int m = 0; m < 4; m++) {
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
}
}
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q8_0_4x8_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
for (int y = 0; y < nr; y += 4) {
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
for (int x = 0; x < nc; x += ncols_interleaved) {
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
const block_q8_0x4 * a_ptr = a_ptr_base;
float32x4_t acc_f32[4];
for (int i = 0; i < 4; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
int32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vdupq_n_s32(0);
}
// Process 4 chunks of 8 positions each
for (int chunk = 0; chunk < 4; chunk++) {
int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
acc[0] = vmmlaq_s32(acc[0], a01, b01);
acc[1] = vmmlaq_s32(acc[1], a01, b23);
acc[2] = vmmlaq_s32(acc[2], a23, b01);
acc[3] = vmmlaq_s32(acc[3], a23, b23);
}
// Reorder outputs from 2×2 tiles to row-major
// acc[0] = [r0c0, r0c1, r1c0, r1c1]
// acc[1] = [r0c2, r0c3, r1c2, r1c3]
// acc[2] = [r2c0, r2c1, r3c0, r3c1]
// acc[3] = [r2c2, r2c3, r3c2, r3c3]
int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
// Scales
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
a_ptr++;
b_ptr++;
}
for (int row = 0; row < 4; row++) {
vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
}
}
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

View File

@ -692,6 +692,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
}
}
void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 4;
assert(nr == 1);
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[4];
int sumi;
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / blocklen); k++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
sumi += v0 * a_ptr[l].qs[k * blocklen + i];
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
}
void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert(nr == 1);
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[4];
int sumi;
const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / blocklen); k++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
sumi += v0 * a_ptr[l].qs[k * blocklen + i];
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
}
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -1219,8 +1313,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
}
}
void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][4];
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / blocklen); k++) {
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
}
sumf[m][j] +=
sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
}
void ggml_gemm_q8_0_4x8_q8_0_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK8_0;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][4];
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / blocklen); k++) {
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
}
sumf[m][j] +=
sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
}
} // extern "C"
static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
block_q8_0x4 out;
for (int i = 0; i < 4; i++) {
out.d[i] = in[i].d;
}
const int end = QK8_0 * 4 / blck_size_interleave;
for (int i = 0; i < end; ++i) {
int src_id = i % 4;
int src_offset = (i / 4) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
}
return out;
}
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
block_q4_0x4 out;
@ -1534,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t,
int interleave_block,
const void * GGML_RESTRICT data,
size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 4;
block_q8_0x4 * dst = (block_q8_0x4 *) t->data;
const block_q8_0 * src = (const block_q8_0 *) data;
block_q8_0 dst_tmp[4];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK8_0;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
}
src += nrows_interleaved * nblocks;
}
return 0;
}
static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
block_iq4_nlx4 out;
@ -1702,6 +1949,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
}
template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
}
// gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int);
@ -1738,6 +1993,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int);
@ -1774,6 +2037,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
class tensor_traits_base : public ggml::cpu::tensor_traits {
public:
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@ -2168,6 +2439,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
// instance for Q8_0
static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
|| (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
@ -2218,6 +2493,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &iq4_nl_4x4_q8_0;
}
}
} else if (cur->type == GGML_TYPE_Q8_0) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 4 == 0) {
return &q8_0_4x8_q8_0;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 4 == 0) {
return &q8_0_4x4_q8_0;
}
}
}
return nullptr;

View File

@ -98,6 +98,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@ -120,6 +124,10 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined(__cplusplus)
} // extern "C"

View File

@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
argmax = shared_argmax[lane_id];
}
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1) {
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {

View File

@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
// For the A/C matrices this means I major == row major, J major == column major.
// For the B matrix this means I major == column major, J major == row major.
// MIRRORED == Each data value is held exactly once per thread subgroup.
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
};
// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
constexpr bool is_i_major(const data_layout dl) {
return dl == DATA_LAYOUT_I_MAJOR ||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
}
constexpr data_layout get_input_data_layout() {
#if defined(RDNA3)
return DATA_LAYOUT_I_MAJOR_DUAL;
#else
return DATA_LAYOUT_I_MAJOR;
#endif // defined(RDNA3)
}
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
struct tile {};
@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
return threadIdx.x % 32;
} else {
NO_DEVICE_CODE;
return -1;
@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
NO_DEVICE_CODE;
return -1;
@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
#elif defined(RDNA3)
static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
#endif // defined(RDNA4)
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
#if defined(RDNA4)
return 8 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return 2 * l + (threadIdx.x / 16);
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
if constexpr (supported()) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
@ -201,7 +208,17 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
// matrix C
#if defined(RDNA3)
return 2 * l + (threadIdx.x / 16);
#else
return ne * (threadIdx.x / 16) + l;
#endif // defined(RDNA3)
} else if constexpr (I == 16 && J == 8) {
// mmq input for RDNA4
return ne * (threadIdx.x / 16) + l;
} else if constexpr (I == 16 && J == 4) {
return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@ -317,14 +329,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA3)
// RDNA3 has duplicated data as input.
static constexpr int ne = I * J / 32 * 2;
#else
static constexpr int ne = I * J / 32;
#endif // defined(RDNA3)
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#elif defined(RDNA3)
return l;
#else
NO_DEVICE_CODE;
return -1;
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
}
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#else
static constexpr int ne = I * J / WARP_SIZE;
@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
#endif // defined(AMD_WMMA_AVAILABLE)
};
template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
T x[ne] = {0};
static constexpr __device__ bool supported() {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
};
template <int I_, int J_>
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
static constexpr int I = I_;
@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
}
};
template <int I_, int J_, typename T>
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
static constexpr int ne = I * J / 32 * 2;
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (supported()) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (supported()) {
return l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
};
#if defined(TURING_MMA_AVAILABLE)
template <int I, int J>
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#if defined(RDNA4)
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
#elif defined(RDNA3)
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
#else
NO_DEVICE_CODE;
#endif // defined(RDNA4)
} else if constexpr (std::is_same_v<T, int>) {
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
#if defined(RDNA4)
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(RDNA3)
static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
xi[0] = xs[0];
xi[1] = xs[1];
#endif // defined(RDNA4)
} else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
#if defined(RDNA4)
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];
const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
#elif defined(RDNA3)
static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
// contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
xi[0] = xs[0];
xi[1] = xs[1];
const int64_t * xs1 = xs + 2;
xi[2] = xs1[0];
xi[3] = xs1[1];
#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
// All wmma layout has contiguous data when i-major.
if constexpr (is_i_major(dl)) {
// the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
if constexpr (sizeof(t.x) > aligned_copy_bytes) {
static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
#pragma unroll
for (int i = 0; i < aligned_copy_count; ++i) {
ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
}
} else {
NO_DEVICE_CODE;
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
} else {
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
}
#else
#pragma unroll
@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
template <typename T>
template <typename T, data_layout dl>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_d, data_layout dl_ab>
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t * acc = (int32x4_t *) D.x;
@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}
template <data_layout dl_d, data_layout dl_ab>
static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

View File

@ -32,11 +32,13 @@ static __global__ void mul_mat_f(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
typedef tile<16, 8, T> tile_A;
typedef tile<tile_B_I, 8, T> tile_B;
typedef tile<16, tile_C_J, float> tile_C;
constexpr bool is_tf32 = std::is_same_v<T, float>;
constexpr int tile_B_I = is_tf32 ? 8 : 16;
constexpr int tile_C_J = is_tf32 ? 8 : 16;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
typedef tile<16, 8, T, ab_layout> tile_A;
typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {

View File

@ -797,9 +797,10 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -966,9 +967,10 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -1130,10 +1132,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
typedef tile<64, 2, int> tile_load;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -1179,9 +1182,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
typedef tile<16, 16, int> tile_C;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -1435,10 +1439,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
typedef tile<64, 2, int> tile_load;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -1501,10 +1506,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
typedef tile<16, 16, int> tile_C;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -2265,10 +2270,11 @@ template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, int> tile_A;
typedef tile<16, 8, int> tile_B;
typedef tile<16, 16, int> tile_C;
typedef tile<64, 2, int> tile_load;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 8, int, input_layout> tile_A;
typedef tile<16, 8, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -2316,9 +2322,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
typedef tile<16, 16, int> tile_C;
constexpr data_layout input_layout = get_input_data_layout();
typedef tile<16, 4, int, input_layout> tile_A;
typedef tile<16, 4, int, input_layout> tile_B;
typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@ -3015,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int tileC_IJ = mmq_get_granularity_device(0);
typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int rows_per_warp = granularity;
#else
typedef tile<16, 8, int> tile_C;

View File

@ -288,7 +288,7 @@ class LocalTensor:
data_range: LocalTensorRange
def mmap_bytes(self) -> np.ndarray:
return np.memmap(self.data_range.filename, offset=self.data_range.offset, shape=self.data_range.size)
return np.memmap(self.data_range.filename, mode='r', offset=self.data_range.offset, shape=self.data_range.size)
class SafetensorsLocal:

View File

@ -2055,7 +2055,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT_NORM_LFM2,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_EXPS,
LLM_TENSOR_FFN_DOWN_EXPS,

View File

@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
}
void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
if (!smpl) {
return;
}
if (smpl->iface->accept) {
smpl->iface->accept(smpl, token);
}
}
void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
if (!smpl) {
return;
}
GGML_ASSERT(smpl->iface->apply);
smpl->iface->apply(smpl, cur_p);
}
void llama_sampler_reset(struct llama_sampler * smpl) {
if (!smpl) {
return;
}
if (smpl->iface->reset) {
smpl->iface->reset(smpl);
}
}
struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
if (!smpl) {
return nullptr;
}
if (smpl->iface->clone) {
return smpl->iface->clone(smpl);
}

View File

@ -73,6 +73,7 @@ static std::vector<llama_device_memory_data> llama_get_device_memory_data(
llama_model_params mparams_copy = *mparams;
mparams_copy.no_alloc = true;
mparams_copy.use_mmap = false;
mparams_copy.use_mlock = false;
llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
if (model == nullptr) {
@ -184,6 +185,7 @@ static void llama_params_fit_impl(
int64_t sum_projected_free = 0;
int64_t min_projected_free = INT64_MAX;
int64_t sum_projected_used = 0;
int64_t sum_projected_model = 0;
int64_t sum_projected_ctx = 0;
if (nd > 1) {
@ -199,6 +201,7 @@ static void llama_params_fit_impl(
sum_projected_used += projected_used;
sum_projected_free += projected_free;
min_projected_free = std::min(min_projected_free, projected_free);
sum_projected_model += dmd.mb.model;
sum_projected_ctx += dmd.mb.context;
if (nd > 1) {
@ -234,10 +237,24 @@ static void llama_params_fit_impl(
if (cparams->n_ctx == 0) {
if (hp_nct > n_ctx_min) {
const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct;
const uint32_t ctx_reduction = std::min(
uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
int64_t memory_reduction = -global_surplus;
if (nd > 1) {
// for multiple devices we need to be more conservative in terms of how much context we think can fit:
// - for dense models only whole layers can be assigned to devices
// - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer
// - on average we expect a waste of 0.5 layers/tensors per device
// - use slightly more than the expected average for nd devices to be safe
const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl);
memory_reduction += (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6);
}
uint32_t ctx_reduction = std::min(uint32_t((memory_reduction + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min);
cparams->n_ctx = hp_nct - ctx_reduction;
const int64_t memory_reduction = ctx_reduction * bytes_per_ctx;
cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend
ctx_reduction = hp_nct - cparams->n_ctx;
memory_reduction = ctx_reduction * bytes_per_ctx;
global_surplus += memory_reduction;
LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n",
__func__, hp_nct, cparams->n_ctx, memory_reduction/MiB);
@ -481,8 +498,13 @@ static void llama_params_fit_impl(
} else {
LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__);
}
uint32_t n_unassigned = hp_ngl;
for (int id = nd - 1; id >= 0; id--) {
uint32_t n_unassigned = hp_ngl;
for (size_t jd = id + 1; jd < nd; ++jd) {
assert(n_unassigned >= ngl_per_device[jd].n_layer);
n_unassigned -= ngl_per_device[jd].n_layer;
}
std::vector<ngl_t> ngl_per_device_high = ngl_per_device;
ngl_per_device_high[id].n_layer = n_unassigned;
if (hp_nex > 0) {
@ -491,7 +513,9 @@ static void llama_params_fit_impl(
if (ngl_per_device_high[id].n_layer > 0) {
std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe);
if (mem_high[id] > targets[id]) {
assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer);
uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta);
while (delta > 1) {
uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]);
step_size = std::max(step_size, uint32_t(1));
@ -507,18 +531,17 @@ static void llama_params_fit_impl(
if (mem_test[id] <= targets[id]) {
ngl_per_device = ngl_per_device_test;
mem = mem_test;
n_unassigned -= ngl_per_device[id].n_layer;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
} else {
ngl_per_device_high = ngl_per_device_test;
mem_high = mem_test;
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer);
}
delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer;
}
} else {
assert(ngl_per_device_high[id].n_layer == n_unassigned);
ngl_per_device = ngl_per_device_high;
n_unassigned -= ngl_per_device[id].n_layer;
LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer);
}
}

View File

@ -329,6 +329,7 @@ struct mtmd_context {
case PROJECTOR_TYPE_QWEN25O:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_GLMA:
audio_preproc = std::make_unique<mtmd_audio_preprocessor_whisper>(ctx_a);
break;
default:

View File

@ -17,6 +17,7 @@
#include <chrono>
#include <queue>
#include <filesystem>
#include <cstring>
#ifdef _WIN32
#include <winsock2.h>
@ -33,7 +34,8 @@
#include <limits.h>
#endif
#define CMD_EXIT "exit"
#define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
// address for child process, this is needed because router may run on 0.0.0.0
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
@ -534,6 +536,8 @@ void server_models::load(const std::string & name) {
std::vector<char *> argv = to_char_ptr_array(child_args);
std::vector<char *> envp = to_char_ptr_array(child_env);
// TODO @ngxson : maybe separate stdout and stderr in the future
// so that we can use stdout for commands and stderr for logging
int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
if (result != 0) {
@ -547,11 +551,17 @@ void server_models::load(const std::string & name) {
// captured variables are guaranteed to be destroyed only after the thread is joined
inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
// read stdout/stderr and forward to main server log
bool state_received = false; // true if child state received
FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
if (p_stdout_stderr) {
char buffer[4096];
while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
LOG("[%5d] %s", port, buffer);
if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
// child process is ready
this->update_status(name, SERVER_MODEL_STATUS_LOADED);
state_received = true;
}
}
} else {
SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
@ -595,7 +605,7 @@ static void interrupt_subprocess(FILE * stdin_file) {
// because subprocess.h does not provide a way to send SIGINT,
// we will send a command to the child process to exit gracefully
if (stdin_file) {
fprintf(stdin_file, "%s\n", CMD_EXIT);
fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
fflush(stdin_file);
}
}
@ -707,32 +717,13 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
return proxy;
}
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
// send a notification to the router server that a model instance is ready
// TODO @ngxson : use HTTP client from libcommon
httplib::Client cli(base_params.hostname, router_port);
cli.set_connection_timeout(0, 200000); // 200 milliseconds
httplib::Request req;
req.method = "POST";
req.path = "/models/status";
req.set_header("Content-Type", "application/json");
if (!base_params.api_keys.empty()) {
req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
}
json body;
body["model"] = name;
body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
req.body = body.dump();
SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
auto result = cli.send(std::move(req));
if (result.error() != httplib::Error::Success) {
auto err_str = httplib::to_string(result.error());
SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
exit(1); // force exit
}
common_log_pause(common_log_main());
fflush(stdout);
fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
fflush(stdout);
common_log_resume(common_log_main());
// setup thread for monitoring stdin
return std::thread([shutdown_handler]() {
@ -746,7 +737,7 @@ std::thread server_models::setup_child_server(const common_params & base_params,
eof = true;
break;
}
if (line.find(CMD_EXIT) != std::string::npos) {
if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
SRV_INF("%s", "exit command received, exiting...\n");
shutdown_handler(0);
break;
@ -869,18 +860,6 @@ void server_models_routes::init_routes() {
return res;
};
// used by child process to notify the router about status change
// TODO @ngxson : maybe implement authentication for this endpoint in the future
this->post_router_models_status = [this](const server_http_req & req) {
auto res = std::make_unique<server_http_res>();
json body = json::parse(req.body);
std::string model = json_value(body, "model", std::string());
std::string value = json_value(body, "value", std::string());
models.update_status(model, server_model_status_from_string(value));
res_ok(res, {{"success", true}});
return res;
};
this->get_router_models = [this](const server_http_req &) {
auto res = std::make_unique<server_http_res>();
json models_json = json::array();

View File

@ -144,7 +144,7 @@ public:
// notify the router server that a model instance is ready
// return the monitoring thread (to be joined by the caller)
static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
};
struct server_models_routes {
@ -162,7 +162,6 @@ struct server_models_routes {
server_http_context::handler_t proxy_post;
server_http_context::handler_t get_router_models;
server_http_context::handler_t post_router_models_load;
server_http_context::handler_t post_router_models_status;
server_http_context::handler_t post_router_models_unload;
};

View File

@ -153,7 +153,6 @@ int main(int argc, char ** argv, char ** envp) {
routes.get_models = models_routes->get_router_models;
ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load));
ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
ctx_http.post("/models/status", ex_wrapper(models_routes->post_router_models_status));
}
ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
@ -291,7 +290,7 @@ int main(int argc, char ** argv, char ** envp) {
const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
std::thread monitor_thread;
if (router_port != nullptr) {
monitor_thread = server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler);
monitor_thread = server_models::setup_child_server(shutdown_handler);
}
// this call blocks the main thread until queue_tasks.terminate() is called