Compare commits

..

No commits in common. "master" and "b7378" have entirely different histories.

257 changed files with 6632 additions and 13461 deletions

View File

@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
# ENTRYPOINT ["/app/llama-server"]
### Target: light
# Lightweight image containing only llama-cli and llama-completion
# Lightweight image containing only llama-cli
# ==============================================================================
FROM base AS light

View File

@ -23,12 +23,11 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
RUN echo "Building with static libs" && \
source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_TESTS=OFF && \
cmake --build build --config Release --target llama-cli && \
cmake --build build --config Release --target llama-completion
cmake --build build --config Release --target llama-cli
# TODO: use image with NNRT
FROM ascendai/cann:$ASCEND_VERSION AS runtime
COPY --from=build /app/build/bin/llama-cli /app/build/bin/llama-completion /
COPY --from=build /app/build/bin/llama-cli /llama-cli
ENV LC_ALL=C.utf8

View File

@ -37,7 +37,6 @@ make -j GGML_CUDA=1
%install
mkdir -p %{buildroot}%{_bindir}/
cp -p llama-cli %{buildroot}%{_bindir}/llama-cuda-cli
cp -p llama-completion %{buildroot}%{_bindir}/llama-cuda-completion
cp -p llama-server %{buildroot}%{_bindir}/llama-cuda-server
cp -p llama-simple %{buildroot}%{_bindir}/llama-cuda-simple
@ -69,7 +68,6 @@ rm -rf %{_builddir}/*
%files
%{_bindir}/llama-cuda-cli
%{_bindir}/llama-cuda-completion
%{_bindir}/llama-cuda-server
%{_bindir}/llama-cuda-simple
/usr/lib/systemd/system/llamacuda.service

View File

@ -39,7 +39,6 @@ make -j
%install
mkdir -p %{buildroot}%{_bindir}/
cp -p llama-cli %{buildroot}%{_bindir}/llama-cli
cp -p llama-completion %{buildroot}%{_bindir}/llama-completion
cp -p llama-server %{buildroot}%{_bindir}/llama-server
cp -p llama-simple %{buildroot}%{_bindir}/llama-simple
@ -71,7 +70,6 @@ rm -rf %{_builddir}/*
%files
%{_bindir}/llama-cli
%{_bindir}/llama-completion
%{_bindir}/llama-server
%{_bindir}/llama-simple
/usr/lib/systemd/system/llama.service

View File

@ -11,7 +11,7 @@ body:
(i.e. the generated text) are incorrect or llama.cpp crashes during model evaluation.
If you encountered the issue while using an external UI (e.g. ollama),
please reproduce your issue using one of the examples/binaries in this repository.
The `llama-completion` binary can be used for simple and reproducible model inference.
The `llama-cli` binary can be used for simple and reproducible model inference.
- type: textarea
id: version
attributes:
@ -74,12 +74,9 @@ body:
Please give us a summary of the problem and tell us how to reproduce it.
If you can narrow down the bug to specific hardware, compile flags, or command line arguments,
that information would be very much appreciated by us.
If possible, please try to reproduce the issue using `llama-completion` with `-fit off`.
If you can only reproduce the issue with `-fit on`, please provide logs both with and without `--verbose`.
placeholder: >
e.g. when I run llama-completion with `-fa on` I get garbled outputs for very long prompts.
With short prompts or `-fa off` it works correctly.
e.g. when I run llama-cli with -ngl 99 I get garbled outputs.
When I use -ngl 0 it works correctly.
Here are the exact commands that I used: ...
validations:
required: true

View File

@ -86,7 +86,6 @@ body:
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
This will be automatically formatted into code, so no need for backticks.
If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
render: shell
validations:
required: false

View File

@ -20,8 +20,7 @@ on:
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl'
'**/*.comp'
]
pull_request:
@ -41,8 +40,7 @@ on:
'**/*.swift',
'**/*.m',
'**/*.metal',
'**/*.comp',
'**/*.glsl'
'**/*.comp'
]
concurrency:

View File

@ -1,225 +0,0 @@
# Server WebUI build and tests
name: Server WebUI
on:
workflow_dispatch: # allows manual triggering
inputs:
sha:
description: 'Commit SHA1 to build'
required: false
type: string
slow_tests:
description: 'Run slow tests'
required: true
type: boolean
push:
branches:
- master
paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
pull_request:
types: [opened, synchronize, reopened]
paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
env:
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
LLAMA_LOG_VERBOSITY: 10
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
webui-check:
name: WebUI Checks
runs-on: ubuntu-latest
continue-on-error: true
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
id: node
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Install dependencies
id: setup
if: ${{ steps.node.conclusion == 'success' }}
run: npm ci
working-directory: tools/server/webui
- name: Run type checking
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run check
working-directory: tools/server/webui
- name: Run linting
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run lint
working-directory: tools/server/webui
- name: Build application
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npm run build
working-directory: tools/server/webui
- name: Install Playwright browsers
id: playwright
if: ${{ always() && steps.setup.conclusion == 'success' }}
run: npx playwright install --with-deps
working-directory: tools/server/webui
- name: Build Storybook
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run build-storybook
working-directory: tools/server/webui
- name: Run Client tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:client
working-directory: tools/server/webui
- name: Run Unit tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:unit
working-directory: tools/server/webui
- name: Run UI tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:ui -- --testTimeout=60000
working-directory: tools/server/webui
- name: Run E2E tests
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:e2e
working-directory: tools/server/webui
server-build:
runs-on: ubuntu-latest
strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
build_type: [RelWithDebInfo]
include:
- build_type: Release
sanitizer: ""
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
steps:
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get -y install \
build-essential \
xxd \
git \
cmake \
curl \
wget \
language-pack-en \
libssl-dev
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Python setup
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Tests dependencies
id: test_dependencies
run: |
pip install -r tools/server/tests/requirements.txt
- name: Setup Node.js for WebUI
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Install WebUI dependencies
run: npm ci
working-directory: tools/server/webui
- name: Build WebUI
run: npm run build
working-directory: tools/server/webui
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_OPENMP=OFF ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build_sanitizers
if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build
if: ${{ matrix.sanitizer == '' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ matrix.sanitizer == '' }}
env:
GITHUB_ACTIONS: "true"
run: |
cd tools/server/tests
./tests.sh
- name: Tests (sanitizers)
id: server_integration_tests_sanitizers
if: ${{ matrix.sanitizer != '' }}
run: |
cd tools/server/tests
LLAMA_SANITIZE=1 ./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
SLOW_TESTS=1 ./tests.sh

View File

@ -76,6 +76,270 @@ jobs:
run: |
pip install -r tools/server/tests/requirements.txt
webui-setup:
name: WebUI Setup
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Cache node_modules
uses: actions/cache@v4
id: cache-node-modules
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Install dependencies
if: steps.cache-node-modules.outputs.cache-hit != 'true'
run: npm ci
working-directory: tools/server/webui
webui-check:
needs: webui-setup
name: WebUI Check
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Run type checking
run: npm run check
working-directory: tools/server/webui
- name: Run linting
run: npm run lint
working-directory: tools/server/webui
webui-build:
needs: webui-check
name: WebUI Build
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Build application
run: npm run build
working-directory: tools/server/webui
webui-tests:
needs: webui-build
name: Run WebUI tests
permissions:
contents: read
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: "22"
- name: Restore node_modules cache
uses: actions/cache@v4
with:
path: tools/server/webui/node_modules
key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
restore-keys: |
${{ runner.os }}-node-modules-
- name: Install Playwright browsers
run: npx playwright install --with-deps
working-directory: tools/server/webui
- name: Build Storybook
run: npm run build-storybook
working-directory: tools/server/webui
- name: Run Client tests
run: npm run test:client
working-directory: tools/server/webui
- name: Run Server tests
run: npm run test:server
working-directory: tools/server/webui
- name: Run UI tests
run: npm run test:ui -- --testTimeout=60000
working-directory: tools/server/webui
- name: Run E2E tests
run: npm run test:e2e
working-directory: tools/server/webui
server-build:
needs: [webui-tests]
runs-on: ubuntu-latest
strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
build_type: [RelWithDebInfo]
include:
- build_type: Release
sanitizer: ""
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
steps:
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get -y install \
build-essential \
xxd \
git \
cmake \
curl \
wget \
language-pack-en \
libssl-dev
- name: Clone
id: checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Python setup
id: setup_python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Tests dependencies
id: test_dependencies
run: |
pip install -r tools/server/tests/requirements.txt
- name: Setup Node.js for WebUI
uses: actions/setup-node@v4
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Install WebUI dependencies
run: npm ci
working-directory: tools/server/webui
- name: Build WebUI
run: npm run build
working-directory: tools/server/webui
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_OPENMP=OFF ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build_sanitizers
if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build
if: ${{ matrix.sanitizer == '' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=OFF \
-DLLAMA_OPENSSL=ON \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ matrix.sanitizer == '' }}
env:
GITHUB_ACTIONS: "true"
run: |
cd tools/server/tests
./tests.sh
- name: Tests (sanitizers)
id: server_integration_tests_sanitizers
if: ${{ matrix.sanitizer != '' }}
run: |
cd tools/server/tests
LLAMA_SANITIZE=1 ./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
SLOW_TESTS=1 ./tests.sh
server-windows:
runs-on: windows-2022

1
.gitignore vendored
View File

@ -54,7 +54,6 @@
/out/
/tmp/
/autogen-*.md
/common/build-info.cpp
# Deprecated

View File

@ -32,7 +32,7 @@
/examples/export-docs/ @ggerganov
/examples/gen-docs/ @ggerganov
/examples/gguf/ @ggerganov
/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren
/examples/llama.android/ @ggerganov
/examples/llama.swiftui/ @ggerganov
/examples/llama.vim @ggerganov
/examples/lookahead/ @ggerganov
@ -87,8 +87,7 @@
/tests/ @ggerganov
/tests/test-chat-.* @pwilkin
/tools/batched-bench/ @ggerganov
/tools/cli/ @ngxson
/tools/completion/ @ggerganov
/tools/main/ @ggerganov
/tools/mtmd/ @ngxson
/tools/perplexity/ @ggerganov
/tools/quantize/ @ggerganov

View File

@ -190,7 +190,6 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
- Android: [llama.android](/examples/llama.android)
</details>
@ -314,7 +313,7 @@ The Hugging Face platform provides a variety of online tools for converting, qua
To learn more about model quantization, [read this documentation](tools/quantize/README.md)
## [`llama-cli`](tools/cli)
## [`llama-cli`](tools/main)
#### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality.
@ -526,8 +525,7 @@ To learn more about model quantization, [read this documentation](tools/quantize
## Other documentation
- [cli](tools/cli/README.md)
- [completion](tools/completion/README.md)
- [main (cli)](tools/main/README.md)
- [server](tools/server/README.md)
- [GBNF grammars](grammars/README.md)

View File

@ -68,6 +68,3 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report.
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080

View File

@ -398,8 +398,6 @@ function gg_run_qwen3_0_6b {
./bin/llama-quantize ${model_bf16} ${model_q5_k} q5_k $(nproc)
./bin/llama-quantize ${model_bf16} ${model_q6_k} q6_k $(nproc)
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
(time ./bin/llama-completion -no-cnv --model ${model_f16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-completion -no-cnv --model ${model_bf16} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-bf16.log
(time ./bin/llama-completion -no-cnv --model ${model_q8_0} -ngl 99 -c 1024 -s 1234 -n 64 --ignore-eos -p "I believe the meaning of life is" ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@ -525,8 +523,6 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
@ -567,8 +563,6 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
(time ./bin/llama-fit-params --model ${model_f16} 2>&1 | tee -a $OUT/${ci}-fp-f16.log)
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log

View File

@ -20,7 +20,6 @@
#include <nlohmann/json.hpp>
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cstdarg>
#include <fstream>
@ -420,8 +419,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@ -432,9 +429,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
@ -510,7 +504,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage) {
throw std::invalid_argument("error: --model is required\n");
}
@ -535,9 +529,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.kv_overrides.back().key[0] = 0;
}
// pad tensor_buft_overrides for llama_params_fit:
const size_t ntbo = llama_max_tensor_buft_overrides();
while (params.tensor_buft_overrides.size() < ntbo) {
if (!params.tensor_buft_overrides.empty()) {
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
@ -732,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
}
}
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
common_params dummy_params;
common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);
@ -741,9 +733,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
for (const auto & arg : opt.args) {
arg_to_options[arg] = &opt;
}
for (const auto & arg : opt.args_neg) {
arg_to_options[arg] = &opt;
}
}
// TODO @ngxson : find a way to deduplicate this code
@ -755,8 +744,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
}
};
std::set<std::string> seen_args;
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@ -767,9 +754,6 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
if (!seen_args.insert(arg).second) {
LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
}
auto opt = *arg_to_options[arg];
std::string val;
if (opt.value_hint != nullptr) {
@ -845,19 +829,6 @@ bool common_arg_utils::is_autoy(const std::string & value) {
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// per-example default params
// we define here to make sure it's included in llama-gen-docs
if (ex == LLAMA_EXAMPLE_COMPLETION) {
params.use_jinja = false; // disable jinja by default
} else if (ex == LLAMA_EXAMPLE_MTMD) {
params.use_jinja = false; // disable jinja by default
params.sampling.temp = 0.2; // lower temp by default for better quality
} else if (ex == LLAMA_EXAMPLE_SERVER) {
params.n_parallel = -1; // auto by default
}
params.use_color = tty_can_use_colors();
// load dynamic backends
@ -1130,7 +1101,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
{"--ctx-checkpoints", "--swa-checkpoints"}, "N",
string_format("max number of context checkpoints to create per slot (default: %d)"
string_format("max number of context checkpoints to create per slot (default: %d)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
params.n_ctx_checkpoints = value;
@ -1138,7 +1109,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--cache-ram", "-cram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
[](common_params & params, int value) {
params.cache_ram_mib = value;
@ -1146,11 +1117,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
[](common_params & params) {
params.kv_unified = true;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY}));
).set_env("LLAMA_ARG_KV_UNIFIED"));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@ -1236,15 +1208,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DIFFUSION}));
add_opt(common_arg(
{"--in-file"}, "FNAME",
"an input file (use comma-separated values to specify multiple files)",
"an input file (repeat to specify multiple files)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
std::ifstream file(item);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
}
params.in_files.push_back(item);
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
params.in_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
@ -1442,7 +1412,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
}
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
).set_sparam());
add_opt(common_arg(
{"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
@ -1912,27 +1882,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
}
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
if (ex == LLAMA_EXAMPLE_SERVER) {
// this is to make sure this option appears in the server-specific section of the help message
add_opt(common_arg(
{"-np", "--parallel"}, "N",
string_format("number of server slots (default: %d, -1 = auto)", params.n_parallel),
[](common_params & params, int value) {
if (value == 0) {
throw std::invalid_argument("error: invalid value for n_parallel\n");
}
params.n_parallel = value;
}
).set_env("LLAMA_ARG_N_PARALLEL").set_examples({LLAMA_EXAMPLE_SERVER}));
} else {
add_opt(common_arg(
{"-np", "--parallel"}, "N",
string_format("number of parallel sequences to decode (default: %d)", params.n_parallel),
[](common_params & params, int value) {
params.n_parallel = value;
}
).set_env("LLAMA_ARG_N_PARALLEL"));
}
add_opt(common_arg(
{"-np", "--parallel"}, "N",
string_format("number of parallel sequences to decode (default: %d)", params.n_parallel),
[](common_params & params, int value) {
params.n_parallel = value;
}
).set_env("LLAMA_ARG_N_PARALLEL"));
add_opt(common_arg(
{"-ns", "--sequences"}, "N",
string_format("number of sequences to decode (default: %d)", params.n_sequences),
@ -1981,11 +1937,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image", "--audio"}, "FILE",
"path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
"path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
params.image.emplace_back(item);
}
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
@ -2196,34 +2150,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_env("LLAMA_ARG_MAIN_GPU"));
add_opt(common_arg(
{ "-fit", "--fit" }, "[on|off]",
string_format("whether to adjust unset arguments to fit in device memory ('on' or 'off', default: '%s')", params.fit_params ? "on" : "off"),
[](common_params & params, const std::string & value) {
if (is_truthy(value)) {
params.fit_params = true;
} else if (is_falsey(value)) {
params.fit_params = false;
} else {
throw std::runtime_error(
string_format("error: unkown value for --fit: '%s'\n", value.c_str()));
}
}
).set_env("LLAMA_ARG_FIT"));
add_opt(common_arg(
{ "-fitt", "--fit-target" }, "MiB",
string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)),
[](common_params & params, int value) {
params.fit_params_target = value * size_t(1024*1024);
}
).set_env("LLAMA_ARG_FIT_TARGET"));
add_opt(common_arg(
{ "-fitc", "--fit-ctx" }, "N",
string_format("minimum ctx size that can be set by --fit option, default: %" PRIu32, params.fit_params_min_ctx),
[](common_params & params, int value) {
params.fit_params_min_ctx = value;
}
).set_env("LLAMA_ARG_FIT_CTX"));
add_opt(common_arg(
{"--check-tensors"},
string_format("check model tensor data for invalid values (default: %s)", params.check_tensors ? "true" : "false"),
@ -2232,39 +2158,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
));
add_opt(common_arg(
{"--override-kv"}, "KEY=TYPE:VALUE,...",
"advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
"types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
{"--override-kv"}, "KEY=TYPE:VALUE",
"advanced option to override model metadata by key. may be specified multiple times.\n"
"types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false",
[](common_params & params, const std::string & value) {
std::vector<std::string> kv_overrides;
std::string current;
bool escaping = false;
for (const char c : value) {
if (escaping) {
current.push_back(c);
escaping = false;
} else if (c == '\\') {
escaping = true;
} else if (c == ',') {
kv_overrides.push_back(current);
current.clear();
} else {
current.push_back(c);
}
}
if (escaping) {
current.push_back('\\');
}
kv_overrides.push_back(current);
for (const auto & kv_override : kv_overrides) {
if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
}
if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) {
throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str()));
}
}
));
@ -2278,50 +2177,33 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--lora"}, "FNAME",
"path to LoRA adapter (use comma-separated values to load multiple adapters)",
"path to LoRA adapter (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
}
params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-scaled"}, "FNAME:SCALE,...",
"path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
"note: use comma-separated values",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
}
params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr });
}
{"--lora-scaled"}, "FNAME", "SCALE",
"path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--control-vector"}, "FNAME",
"add a control vector\nnote: use comma-separated values to add multiple control vectors",
"add a control vector\nnote: this argument can be repeated to add multiple control vectors",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
params.control_vectors.push_back({ 1.0f, item, });
}
params.control_vectors.push_back({ 1.0f, value, });
}
));
add_opt(common_arg(
{"--control-vector-scaled"}, "FNAME:SCALE,...",
{"--control-vector-scaled"}, "FNAME", "SCALE",
"add a control vector with user defined scaling SCALE\n"
"note: use comma-separated values (format: FNAME:SCALE,...)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
auto parts = string_split<std::string>(item, ':');
if (parts.size() != 2) {
throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
}
params.control_vectors.push_back({ std::stof(parts[1]), parts[0] });
}
"note: this argument can be repeated to add multiple scaled control vectors",
[](common_params & params, const std::string & fname, const std::string & scale) {
params.control_vectors.push_back({ std::stof(scale), fname });
}
));
add_opt(common_arg(
@ -2411,15 +2293,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("HF_TOKEN"));
add_opt(common_arg(
{"--context-file"}, "FNAME",
"file to load context from (use comma-separated values to specify multiple files)",
"file to load context from (repeat to specify multiple files)",
[](common_params & params, const std::string & value) {
for (const auto & item : string_split<std::string>(value, ',')) {
std::ifstream file(item, std::ios::binary);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
}
params.context_files.push_back(item);
std::ifstream file(value, std::ios::binary);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
params.context_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
@ -2610,20 +2490,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.api_prefix = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
add_opt(common_arg(
{"--webui-config"}, "JSON",
"JSON that provides default WebUI settings (overrides WebUI defaults)",
[](common_params & params, const std::string & value) {
params.webui_config_json = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG"));
add_opt(common_arg(
{"--webui-config-file"}, "PATH",
"JSON file that provides default WebUI settings (overrides WebUI defaults)",
[](common_params & params, const std::string & value) {
params.webui_config_json = read_file(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
add_opt(common_arg(
{"--webui"},
{"--no-webui"},

View File

@ -115,7 +115,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
// parse input arguments from CLI into a map
// TODO: support repeated args in the future
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

View File

@ -4,14 +4,9 @@
using json = nlohmann::json;
static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
int count = 0;
static std::string_view trim_trailing_space(std::string_view sv) {
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
if (max != -1 && count <= max) {
break;
}
sv.remove_suffix(1);
count++;
}
return sv;
}
@ -98,7 +93,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
if (is_arg_string && current_tool) {
// Serialize to JSON, but exclude the end quote
std::string dumped = json(trim_trailing_space(node.text)).dump();
std::string dumped = json(node.text).dump();
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
needs_closing_quote = true;
}
@ -106,7 +101,6 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
if (is_arg_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
}
}
@ -115,10 +109,6 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
}
if (is_tool_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
needs_closing_quote = false;
}
current_tool->arguments += "}";
}
}

View File

@ -711,25 +711,6 @@ static void foreach_function(const json & tools, const std::function<void(const
}
}
static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
if (!function.contains("parameters") || !function.at("parameters").is_object()) {
return;
}
const auto & params = function.at("parameters");
if (!params.contains("properties") || !params.at("properties").is_object()) {
return;
}
const auto & props = params.at("properties");
std::set<std::string> required;
if (params.contains("required") && params.at("required").is_array()) {
params.at("required").get_to(required);
}
for (const auto & [name, prop] : props.items()) {
bool is_required = (required.find(name) != required.end());
fn(name, prop, is_required);
}
}
static std::string apply(
const common_chat_template & tmpl,
const struct templates_params & inputs,
@ -1428,123 +1409,6 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
return data;
}
static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
data.preserved_tokens = {
"<think>",
"</think>",
"<tool_call>",
"</tool_call>",
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = true;
auto parser = build_chat_peg_constructed_parser([&](auto & p) {
auto reasoning = p.eps();
if (inputs.enable_thinking && extract_reasoning) {
auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
if (data.thinking_forced_open) {
reasoning = reasoning_content;
}
}
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
}
// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
auto schema_info = common_schema_info();
schema_info.resolve_refs(parameters);
auto tool_open = "<function=" + p.tool_name(p.literal(name)) + ">\n";
auto tool_close = p.literal("</function>\n");
auto args = p.sequence();
auto arg_string = p.rule("xml-arg-string", p.until_one_of({
"\n</parameter>",
"\n<parameter=",
"\n</function>"
}));
foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
auto rule_name = "tool-" + name + "-arg-" + param_name;
auto arg_open = "<parameter=" + p.tool_arg_name(p.literal(param_name)) + ">\n";
auto arg_close = p.literal("</parameter>\n");
auto arg_value = p.eps();
if (schema_info.resolves_to_string(param_schema)) {
arg_value = p.tool_arg_string_value(arg_string) + "\n";
} else {
arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
}
// Model may or my not close with </parameter>
auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
});
tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
});
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_call = p.rule("tool-call", "<tool_call>\n" + tool_choice + "</tool_call>" + p.space());
auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
return reasoning << p.content(p.until("<tool_call>")) << tool_calls;
}
// Content only parser
include_grammar = false;
return reasoning << p.content(p.rest());
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<tool_call>"}
};
}
return data;
}
static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@ -2670,10 +2534,6 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("<function=") != std::string::npos &&
src.find("<parameters>") != std::string::npos &&
src.find("<parameter=") != std::string::npos) {
// Nemotron 3 Nano 30B A3B
if (src.find("<think>") != std::string::npos) {
return common_chat_params_init_nemotron_v3(tmpl, params);
}
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
}

View File

@ -1013,40 +1013,31 @@ bool tty_can_use_colors() {
// Model utils
//
// TODO: move to common/sampling
static void common_init_sampler_from_model(
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) {
return;
}
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) {
dst = v;
}
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) {
return;
}
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) {
dst = v;
}
if (end && end != buf) dst = v;
}
};
@ -1074,125 +1065,31 @@ static void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result::impl {
impl() = default;
~impl() = default;
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<common_sampler_ptr> samplers;
};
common_init_result::common_init_result(common_params & params) :
pimpl(new impl{}) {
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
}
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
return;
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return iparams;
}
pimpl->model.reset(model);
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model);
// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
//if (params.sampling.penalty_last_n == -1) {
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
//}
//if (params.sampling.dry_penalty_last_n == -1) {
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}
pimpl->samplers.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
}
auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
return;
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
}
pimpl->context.reset(lctx);
}
llama_model * common_init_result::model() {
return pimpl->model.get();
}
llama_context * common_init_result::context() {
return pimpl->context.get();
}
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
return res;
}
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
@ -1204,7 +1101,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
return res;
llama_free(lctx);
llama_model_free(model);
return iparams;
}
int err = llama_apply_adapter_cvec(
@ -1215,7 +1115,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
return res;
llama_free(lctx);
llama_model_free(model);
return iparams;
}
}
@ -1239,7 +1142,10 @@ common_init_result_ptr common_init_from_params(common_params & params) {
}
if (!ok) {
return res;
llama_free(lctx);
llama_model_free(model);
return iparams;
}
}
@ -1249,7 +1155,9 @@ common_init_result_ptr common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
return res;
llama_free(lctx);
llama_model_free(model);
return iparams;
}
char buf[1024];
@ -1258,13 +1166,43 @@ common_init_result_ptr common_init_from_params(common_params & params) {
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}
if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
@ -1303,10 +1241,11 @@ common_init_result_ptr common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false);
}
return res;
}
iparams.model.reset(model);
iparams.context.reset(lctx);
common_init_result::~common_init_result() = default;
return iparams;
}
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
@ -1316,9 +1255,7 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
if (model_endpoint.back() != '/') model_endpoint += '/';
}
return model_endpoint;
}

View File

@ -99,7 +99,6 @@ enum llama_example {
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_DIFFUSION,
LLAMA_EXAMPLE_FINETUNE,
LLAMA_EXAMPLE_FIT_PARAMS,
LLAMA_EXAMPLE_COUNT,
};
@ -196,6 +195,7 @@ struct common_params_sampling {
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
@ -216,10 +216,6 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
bool has_logit_bias() const {
return !logit_bias.empty();
}
// print the parameters into a string
std::string print() const;
};
@ -307,8 +303,8 @@ struct lr_opt {
struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
struct common_params {
int32_t n_predict = -1; // max. number of new tokens to predict, -1 == no limit
int32_t n_ctx = 0; // context size, 0 == context the model was trained with
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
@ -329,12 +325,9 @@ struct common_params {
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
bool fit_params = true; // whether to fit unset model/context parameters to free device memory
size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory
int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
@ -484,11 +477,8 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
// webui configs
bool webui = true;
std::string webui_config_json;
// "advanced" endpoints are disabled by default for better security
bool webui = true;
bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
@ -679,29 +669,15 @@ bool tty_can_use_colors();
// Model utils
//
struct common_sampler;
// note: defines the model, context, samplers, ets. lifetimes
// note: defines object's lifetime
struct common_init_result {
common_init_result(common_params & params);
~common_init_result();
llama_model_ptr model;
llama_context_ptr context;
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;
std::vector<llama_adapter_lora_ptr> lora;
};
using common_init_result_ptr = std::unique_ptr<common_init_result>;
common_init_result_ptr common_init_from_params(common_params & params);
struct common_init_result common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);

View File

@ -305,9 +305,8 @@ static std::string format_literal(const std::string & literal) {
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
class common_schema_converter {
class SchemaConverter {
private:
friend class common_schema_info;
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
@ -730,7 +729,7 @@ private:
}
public:
common_schema_converter(
SchemaConverter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
@ -991,134 +990,6 @@ public:
}
};
// common_schema_info implementation (pimpl)
common_schema_info::common_schema_info()
: impl_(std::make_unique<common_schema_converter>(
[](const std::string &) { return json(); },
false)) {}
common_schema_info::~common_schema_info() = default;
common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
impl_->resolve_refs(schema, "");
}
// Determines if a JSON schema can resolve to a string type through any path.
// Some models emit raw string values rather than JSON-encoded strings for string parameters.
// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
// true, allowing callers to handle the value as a raw string for simplicity.
bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
std::unordered_set<std::string> visited_refs;
std::function<bool(const json &)> check = [&](const json & s) -> bool {
if (!s.is_object()) {
return false;
}
// Handle $ref
if (s.contains("$ref")) {
const std::string & ref = s["$ref"];
if (visited_refs.find(ref) != visited_refs.end()) {
// Circular reference, assume not a string to be safe
return false;
}
visited_refs.insert(ref);
auto it = impl_->_refs.find(ref);
if (it != impl_->_refs.end()) {
return check(it->second);
}
return false;
}
// Check type field
if (s.contains("type")) {
const json & schema_type = s["type"];
if (schema_type.is_string()) {
if (schema_type == "string") {
return true;
}
} else if (schema_type.is_array()) {
// Type can be an array like ["string", "null"]
for (const auto & t : schema_type) {
if (t == "string") {
return true;
}
}
}
}
// Check oneOf/anyOf - if any alternative can be a string
if (s.contains("oneOf")) {
for (const auto & alt : s["oneOf"]) {
if (check(alt)) {
return true;
}
}
}
if (s.contains("anyOf")) {
for (const auto & alt : s["anyOf"]) {
if (check(alt)) {
return true;
}
}
}
// Check allOf - all components must be compatible with string type
if (s.contains("allOf")) {
bool all_string = true;
for (const auto & component : s["allOf"]) {
if (!check(component)) {
all_string = false;
break;
}
}
if (all_string) {
return true;
}
}
// Check const - if the constant value is a string
if (s.contains("const")) {
if (s["const"].is_string()) {
return true;
}
}
// Check enum - if any enum value is a string
if (s.contains("enum")) {
for (const auto & val : s["enum"]) {
if (val.is_string()) {
return true;
}
}
}
// String-specific keywords imply string type
if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
return true;
}
// Check format - many formats imply string
if (s.contains("format")) {
const std::string & fmt = s["format"];
if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
fmt == "uri" || fmt == "email" || fmt == "hostname" ||
fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
fmt.find("uuid") == 0) {
return true;
}
}
return false;
};
return check(schema);
}
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
@ -1135,7 +1006,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);

View File

@ -3,31 +3,11 @@
#include <nlohmann/json_fwd.hpp>
#include <functional>
#include <memory>
#include <string>
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
class common_schema_converter;
// Probes a JSON schema to extract information about its structure and type constraints.
class common_schema_info {
std::unique_ptr<common_schema_converter> impl_;
public:
common_schema_info();
~common_schema_info();
common_schema_info(const common_schema_info &) = delete;
common_schema_info & operator=(const common_schema_info &) = delete;
common_schema_info(common_schema_info &&) noexcept;
common_schema_info & operator=(common_schema_info &&) noexcept;
void resolve_refs(nlohmann::ordered_json & schema);
bool resolves_to_string(const nlohmann::ordered_json & schema);
};
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;

View File

@ -425,7 +425,7 @@ struct parser_executor {
if (result.need_more_input()) {
// Propagate - need to know what child would match before negating
return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
return result;
}
// Child failed, so negation succeeds

View File

@ -157,21 +157,6 @@ static std::map<std::string, common_arg> get_map_key_opt(common_params_context &
return mapping;
}
static bool is_bool_arg(const common_arg & arg) {
return !arg.args_neg.empty();
}
static std::string parse_bool_arg(const common_arg & arg, const std::string & key, const std::string & value) {
// if this is a negated arg, we need to reverse the value
for (const auto & neg_arg : arg.args_neg) {
if (rm_leading_dashes(neg_arg) == key) {
return common_arg_utils::is_truthy(value) ? "false" : "true";
}
}
// otherwise, not negated
return value;
}
common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) {
common_presets out;
auto key_to_opt = get_map_key_opt(ctx_params);
@ -188,13 +173,8 @@ common_presets common_presets_load(const std::string & path, common_params_conte
for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (key_to_opt.find(key) != key_to_opt.end()) {
auto & opt = key_to_opt[key];
if (is_bool_arg(opt)) {
preset.options[opt] = parse_bool_arg(opt, key, value);
} else {
preset.options[opt] = value;
}
LOG_DBG("accepted option: %s = %s\n", key.c_str(), preset.options[opt].c_str());
preset.options[key_to_opt[key]] = value;
LOG_DBG("accepted option: %s = %s\n", key.c_str(), value.c_str());
} else {
// TODO: maybe warn about unknown key?
}

View File

@ -116,6 +116,7 @@ struct common_sampler {
void reset() {
prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
}
@ -166,11 +167,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
std::vector<llama_sampler *> samplers;
struct llama_sampler * grmr;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
@ -220,20 +217,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
trigger_patterns_c.push_back(regex.c_str());
}
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size());
} else {
grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
if (!grmr) {
return nullptr;
}
}
if (params.has_logit_bias()) {
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
@ -246,71 +253,58 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
c_breakers.push_back(str.c_str());
}
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
samplers.push_back(llama_sampler_init_top_k (params.top_k));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_XTC:
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
samplers.push_back(llama_sampler_init_infill (vocab));
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break;
case COMMON_SAMPLER_TYPE_PENALTIES:
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
}
samplers.push_back(llama_sampler_init_dist(params.seed));
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) {
samplers.push_back(llama_sampler_init_temp(params.temp));
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
for (auto * smpl : samplers) {
llama_sampler_chain_add(chain, smpl);
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ chain,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
return result;
}
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
@ -320,7 +314,7 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (gsmpl->grmr && accept_grammar) {
if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
@ -335,12 +329,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}
@ -389,37 +383,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
llama_token id = LLAMA_TOKEN_NULL;
gsmpl->set_logits(ctx, idx);
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx);
if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p);
id = cur_p.data[cur_p.selected].id;
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
const llama_token id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id;
}
// check if it the sampled token fits the grammar (grammar-based rejection sampling)
// check if it the sampled token fits the grammar
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
@ -439,11 +429,9 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
id = cur_p.data[cur_p.selected].id;
return id;
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
@ -527,8 +515,7 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ");
result += std::string(llama_sampler_name(smpl)) + " ";
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
}
return result;

View File

@ -48,8 +48,6 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// extended sampling implementation:
//
// - set logits
@ -109,9 +107,3 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);
struct common_sampler_deleter {
void operator()(common_sampler * s) { common_sampler_free(s); }
};
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;

File diff suppressed because it is too large Load Diff

View File

@ -143,7 +143,6 @@ models = [
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
{"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
]
# some models are known to be broken upstream, so we will skip them as exceptions

View File

@ -1,26 +1,6 @@
# Android
## Build with Android Studio
Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
![Project imported into Android Studio](./android/imported-into-android-studio.png)
This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
A minimal Android app frontend is included to showcase the bindings core functionalities:
1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` or a local `File`.
2. **Obtain a `TierDetection` or `InferenceEngine`** instance through the high-level facade APIs.
3. **Send a raw user prompt** for automatic template formatting, prefill, and decoding. Then collect the generated tokens in a Kotlin `Flow`.
For a production-ready experience that leverages advanced features such as system prompts and benchmarks, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups:
| ![Home screen](./android/arm-ai-chat-home-screen.png) | ![System prompt](./android/system-prompt-setup.png) | !["Haiku"](./android/chat-with-system-prompt-haiku.png) |
|:------------------------------------------------------:|:----------------------------------------------------:|:--------------------------------------------------------:|
| Home screen | System prompt | "Haiku" |
## Build on Android using Termux
[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.

View File

@ -103,8 +103,6 @@ SYCL backend supports Intel GPU Family:
- Intel Built-in Arc GPU
- Intel iGPU in Core CPU (11th Generation Core CPU and newer, refer to [oneAPI supported GPU](https://www.intel.com/content/www/us/en/developer/articles/system-requirements/intel-oneapi-base-toolkit-system-requirements.html#inpage-nav-1-1)).
On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the performance is not optimal, and some GPUs may not support OpenCL nor have any GPGPU capabilities.
#### Verified devices
| Intel GPU | Status | Verified Model |

View File

@ -9,8 +9,7 @@ Adding a model requires few steps:
After following these steps, you can open PR.
Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially:
- [cli](/tools/cli/)
- [completion](/tools/completion/)
- [main](/tools/main/)
- [imatrix](/tools/imatrix/)
- [quantize](/tools/quantize/)
- [server](/tools/server/)
@ -97,7 +96,7 @@ The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
- Add the architecture name to the `LLM_ARCH_NAMES` map.
- Add the list of model tensors to `llm_get_tensor_names` (you may also need to update `LLM_TENSOR_NAMES`)
- Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.

View File

@ -7,9 +7,9 @@
## Images
We have three Docker images available for this project:
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the `llama-cli` and `llama-completion` executables. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the `llama-server` executable. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
Additionally, there the following images, similar to the above:
@ -44,15 +44,13 @@ docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --all-in-o
On completion, you are ready to play!
```bash
docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf
docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run-legacy -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
```
or with a light image:
```bash
docker run -v /path/to/models:/models --entrypoint /app/llama-cli ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf
docker run -v /path/to/models:/models --entrypoint /app/llama-completion ghcr.io/ggml-org/llama.cpp:light -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
```
or with a server image:
@ -61,8 +59,6 @@ or with a server image:
docker run -v /path/to/models:/models -p 8080:8080 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512
```
In the above examples, `--entrypoint /app/llama-cli` is specified for clarity, but you can safely omit it since it's the default entrypoint in the container.
## Docker With CUDA
Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container.
@ -84,9 +80,9 @@ The defaults are:
The resulting images, are essentially the same as the non-CUDA images:
1. `local/llama.cpp:full-cuda`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
2. `local/llama.cpp:light-cuda`: This image only includes the `llama-cli` and `llama-completion` executables.
3. `local/llama.cpp:server-cuda`: This image only includes the `llama-server` executable.
1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
2. `local/llama.cpp:light-cuda`: This image only includes the main executable file.
3. `local/llama.cpp:server-cuda`: This image only includes the server executable file.
## Usage
@ -118,9 +114,9 @@ The defaults are:
The resulting images, are essentially the same as the non-MUSA images:
1. `local/llama.cpp:full-musa`: This image includes both the `llama-cli` and `llama-completion` executables and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
2. `local/llama.cpp:light-musa`: This image only includes the `llama-cli` and `llama-completion` executables.
3. `local/llama.cpp:server-musa`: This image only includes the `llama-server` executable.
1. `local/llama.cpp:full-musa`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
2. `local/llama.cpp:light-musa`: This image only includes the main executable file.
3. `local/llama.cpp:server-musa`: This image only includes the server executable file.
## Usage

View File

@ -18,12 +18,12 @@ Legend:
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | | ✅ | ❌ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | | 🟡 | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
@ -31,7 +31,7 @@ Legend:
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
@ -64,7 +64,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | ✅ | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
@ -98,14 +98,14 @@ Legend:
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | | 🟡 | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ❌ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | | 🟡 | ❌ | ❌ | ❌ |
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -113,7 +113,7 @@ Legend:
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,6 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cstdio>
@ -65,23 +64,17 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_init_from_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
std::vector<llama_sampler *> samplers;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
samplers.push_back(smpl);
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
@ -180,7 +173,7 @@ int main(int argc, char ** argv) {
continue;
}
const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@ -236,17 +229,14 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
LOG("\n");
llama_perf_sampler_print(samplers[0]);
llama_perf_sampler_print(smpl);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_free(batch);
for (auto & sampler_config : samplers) {
llama_sampler_free(sampler_config);
}
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);

View File

@ -131,10 +131,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -202,10 +202,10 @@ int main(int argc, char ** argv) {
params.warmup = false;
// init
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);

View File

@ -48,7 +48,7 @@ static void write_table(std::ofstream & file, std::vector<common_arg *> & opts)
}
}
static void export_md(std::string fname, llama_example ex, std::string name) {
static void export_md(std::string fname, llama_example ex) {
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
common_params params;
@ -72,14 +72,13 @@ static void export_md(std::string fname, llama_example ex, std::string name) {
write_table(file, common_options);
file << "\n\n**Sampling params**\n\n";
write_table(file, sparam_options);
file << "\n\n**" << name << "-specific params**\n\n";
file << "\n\n**Example-specific params**\n\n";
write_table(file, specific_options);
}
int main(int, char **) {
// TODO: add CLI
export_md("autogen-completion.md", LLAMA_EXAMPLE_COMPLETION, "Tool");
export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER, "Server");
export_md("autogen-main.md", LLAMA_EXAMPLE_COMPLETION);
export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER);
return 0;
}

View File

@ -1,18 +1,16 @@
plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
id("com.android.application")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "com.example.llama"
compileSdk = 36
compileSdk = 34
defaultConfig {
applicationId = "com.example.llama.aichat"
applicationId = "com.example.llama"
minSdk = 33
targetSdk = 36
targetSdk = 34
versionCode = 1
versionName = "1.0"
@ -23,17 +21,8 @@ android {
}
buildTypes {
debug {
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android.txt"),
"proguard-rules.pro"
)
}
release {
isMinifyEnabled = true
isShrinkResources = true
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
@ -47,15 +36,30 @@ android {
kotlinOptions {
jvmTarget = "1.8"
}
buildFeatures {
compose = true
}
composeOptions {
kotlinCompilerExtensionVersion = "1.5.1"
}
}
dependencies {
implementation(libs.bundles.androidx)
implementation(libs.material)
implementation(project(":lib"))
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
implementation("androidx.activity:activity-compose:1.8.2")
implementation(platform("androidx.compose:compose-bom:2023.08.00"))
implementation("androidx.compose.ui:ui")
implementation("androidx.compose.ui:ui-graphics")
implementation("androidx.compose.ui:ui-tooling-preview")
implementation("androidx.compose.material3:material3")
implementation(project(":llama"))
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
androidTestImplementation("androidx.compose.ui:ui-test-junit4")
debugImplementation("androidx.compose.ui:ui-tooling")
debugImplementation("androidx.compose.ui:ui-test-manifest")
}

View File

@ -19,11 +19,3 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-assumenosideeffects class android.util.Log {
public static int v(...);
public static int d(...);
}

View File

@ -1,21 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools">
<uses-permission android:name="android.permission.INTERNET" />
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:extractNativeLibs="true"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher_round"
android:icon="@mipmap/ic_launcher"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.AiChatSample"
android:theme="@style/Theme.LlamaAndroid"
>
<activity
android:name=".MainActivity"
android:exported="true">
android:exported="true"
android:theme="@style/Theme.LlamaAndroid">
<intent-filter>
<action android:name="android.intent.action.MAIN" />

View File

@ -0,0 +1,119 @@
package com.example.llama
import android.app.DownloadManager
import android.net.Uri
import android.util.Log
import androidx.compose.material3.Button
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableDoubleStateOf
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import androidx.core.database.getLongOrNull
import androidx.core.net.toUri
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import java.io.File
data class Downloadable(val name: String, val source: Uri, val destination: File) {
companion object {
@JvmStatic
private val tag: String? = this::class.qualifiedName
sealed interface State
data object Ready: State
data class Downloading(val id: Long): State
data class Downloaded(val downloadable: Downloadable): State
data class Error(val message: String): State
@JvmStatic
@Composable
fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
var status: State by remember {
mutableStateOf(
if (item.destination.exists()) Downloaded(item)
else Ready
)
}
var progress by remember { mutableDoubleStateOf(0.0) }
val coroutineScope = rememberCoroutineScope()
suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
while (true) {
val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
if (cursor == null) {
Log.e(tag, "dm.query() returned null")
return Error("dm.query() returned null")
}
if (!cursor.moveToFirst() || cursor.count < 1) {
cursor.close()
Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
return Ready
}
val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
val sofar = cursor.getLongOrNull(pix) ?: 0
val total = cursor.getLongOrNull(tix) ?: 1
cursor.close()
if (sofar == total) {
return Downloaded(item)
}
progress = (sofar * 1.0) / total
delay(1000L)
}
}
fun onClick() {
when (val s = status) {
is Downloaded -> {
viewModel.load(item.destination.path)
}
is Downloading -> {
coroutineScope.launch {
status = waitForDownload(s, item)
}
}
else -> {
item.destination.delete()
val request = DownloadManager.Request(item.source).apply {
setTitle("Downloading model")
setDescription("Downloading model: ${item.name}")
setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
setDestinationUri(item.destination.toUri())
}
viewModel.log("Saving ${item.name} to ${item.destination.path}")
Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
val id = dm.enqueue(request)
status = Downloading(id)
onClick()
}
}
}
Button(onClick = { onClick() }, enabled = status !is Downloading) {
when (status) {
is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
is Downloaded -> Text("Load ${item.name}")
is Ready -> Text("Download ${item.name}")
is Error -> Text("Download ${item.name}")
}
}
}
}
}

View File

@ -1,257 +1,154 @@
package com.example.llama
import android.app.ActivityManager
import android.app.DownloadManager
import android.content.ClipData
import android.content.ClipboardManager
import android.net.Uri
import android.os.Bundle
import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import com.arm.aichat.AiChat
import com.arm.aichat.InferenceEngine
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import android.os.StrictMode
import android.os.StrictMode.VmPolicy
import android.text.format.Formatter
import androidx.activity.ComponentActivity
import androidx.activity.compose.setContent
import androidx.activity.viewModels
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxSize
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.foundation.lazy.rememberLazyListState
import androidx.compose.material3.Button
import androidx.compose.material3.LocalContentColor
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.OutlinedTextField
import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier
import androidx.compose.ui.unit.dp
import androidx.core.content.getSystemService
import com.example.llama.ui.theme.LlamaAndroidTheme
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.util.UUID
class MainActivity : AppCompatActivity() {
class MainActivity(
activityManager: ActivityManager? = null,
downloadManager: DownloadManager? = null,
clipboardManager: ClipboardManager? = null,
): ComponentActivity() {
private val tag: String? = this::class.simpleName
// Android views
private lateinit var ggufTv: TextView
private lateinit var messagesRv: RecyclerView
private lateinit var userInputEt: EditText
private lateinit var userActionFab: FloatingActionButton
private val activityManager by lazy { activityManager ?: getSystemService<ActivityManager>()!! }
private val downloadManager by lazy { downloadManager ?: getSystemService<DownloadManager>()!! }
private val clipboardManager by lazy { clipboardManager ?: getSystemService<ClipboardManager>()!! }
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
private val viewModel: MainViewModel by viewModels()
// Conversation states
private var isModelReady = false
private val messages = mutableListOf<Message>()
private val lastAssistantMsg = StringBuilder()
private val messageAdapter = MessageAdapter(messages)
// Get a MemoryInfo object for the device's current memory status.
private fun availableMemory(): ActivityManager.MemoryInfo {
return ActivityManager.MemoryInfo().also { memoryInfo ->
activityManager.getMemoryInfo(memoryInfo)
}
}
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
messagesRv.layoutManager = LinearLayoutManager(this)
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
StrictMode.setVmPolicy(
VmPolicy.Builder(StrictMode.getVmPolicy())
.detectLeakedClosableObjects()
.build()
)
// Arm AI Chat initialization
lifecycleScope.launch(Dispatchers.Default) {
engine = AiChat.getInferenceEngine(applicationContext)
}
val free = Formatter.formatFileSize(this, availableMemory().availMem)
val total = Formatter.formatFileSize(this, availableMemory().totalMem)
// Upon CTA button tapped
userActionFab.setOnClickListener {
if (isModelReady) {
// If model is ready, validate input and send to engine
handleUserInput()
} else {
// Otherwise, prompt user to select a GGUF metadata on the device
getContent.launch(arrayOf("*/*"))
}
}
}
viewModel.log("Current memory: $free / $total")
viewModel.log("Downloads directory: ${getExternalFilesDir(null)}")
private val getContent = registerForActivityResult(
ActivityResultContracts.OpenDocument()
) { uri ->
Log.i(TAG, "Selected file uri:\n $uri")
uri?.let { handleSelectedModel(it) }
}
val extFilesDir = getExternalFilesDir(null)
/**
* Handles the file Uri from [getContent] result
*/
private fun handleSelectedModel(uri: Uri) {
// Update UI states
userActionFab.isEnabled = false
userInputEt.hint = "Parsing GGUF..."
ggufTv.text = "Parsing metadata from selected file \n$uri"
val models = listOf(
Downloadable(
"Phi-2 7B (Q4_0, 1.6 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
File(extFilesDir, "phi-2-q4_0.gguf"),
),
Downloadable(
"TinyLlama 1.1B (f16, 2.2 GiB)",
Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
File(extFilesDir, "tinyllama-1.1-f16.gguf"),
),
Downloadable(
"Phi 2 DPO (Q3_K_M, 1.48 GiB)",
Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
),
)
lifecycleScope.launch(Dispatchers.IO) {
// Parse GGUF metadata
Log.i(TAG, "Parsing GGUF metadata...")
contentResolver.openInputStream(uri)?.use {
GgufMetadataReader.create().readStructuredMetadata(it)
}?.let { metadata ->
// Update UI to show GGUF metadata to user
Log.i(TAG, "GGUF parsed: \n$metadata")
withContext(Dispatchers.Main) {
ggufTv.text = metadata.toString()
setContent {
LlamaAndroidTheme {
// A surface container using the 'background' color from the theme
Surface(
modifier = Modifier.fillMaxSize(),
color = MaterialTheme.colorScheme.background
) {
MainCompose(
viewModel,
clipboardManager,
downloadManager,
models,
)
}
// Ensure the model file is available
val modelName = metadata.filename() + FILE_EXTENSION_GGUF
contentResolver.openInputStream(uri)?.use { input ->
ensureModelFile(modelName, input)
}?.let { modelFile ->
loadModel(modelName, modelFile)
withContext(Dispatchers.Main) {
isModelReady = true
userInputEt.hint = "Type and send a message!"
userInputEt.isEnabled = true
userActionFab.setImageResource(R.drawable.outline_send_24)
userActionFab.isEnabled = true
}
}
}
}
}
/**
* Prepare the model file within app's private storage
*/
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
withContext(Dispatchers.IO) {
File(ensureModelsDirectory(), modelName).also { file ->
// Copy the file into local storage if not yet done
if (!file.exists()) {
Log.i(TAG, "Start copying file to $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Copying file..."
}
FileOutputStream(file).use { input.copyTo(it) }
Log.i(TAG, "Finished copying file to $modelName")
} else {
Log.i(TAG, "File already exists $modelName")
}
}
}
/**
* Load the model file from the app private storage
*/
private suspend fun loadModel(modelName: String, modelFile: File) =
withContext(Dispatchers.IO) {
Log.i(TAG, "Loading model $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Loading model..."
}
engine.loadModel(modelFile.path)
}
/**
* Validate and send the user message into [InferenceEngine]
*/
private fun handleUserInput() {
userInputEt.text.toString().also { userSsg ->
if (userSsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
userActionFab.isEnabled = false
// Update message states
messages.add(Message(UUID.randomUUID().toString(), userSsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
lifecycleScope.launch(Dispatchers.Default) {
engine.sendUserPrompt(userSsg)
.onCompletion {
withContext(Dispatchers.Main) {
userActionFab.isEnabled = true
}
}.collect { token ->
val messageCount = messages.size
check(messageCount > 0 && !messages[messageCount - 1].isUser)
messages.removeAt(messageCount - 1).copy(
content = lastAssistantMsg.append(token).toString()
).let { messages.add(it) }
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
}
}
}
/**
* Run a benchmark with the model file
*/
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Running benchmark..."
}
engine.bench(
pp=BENCH_PROMPT_PROCESSING_TOKENS,
tg=BENCH_TOKEN_GENERATION_TOKENS,
pl=BENCH_SEQUENCE,
nr=BENCH_REPETITION
).let { result ->
messages.add(Message(UUID.randomUUID().toString(), result, false))
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
/**
* Create the `models` directory if not exist.
*/
private fun ensureModelsDirectory() =
File(filesDir, DIRECTORY_MODELS).also {
if (it.exists() && !it.isDirectory) { it.delete() }
if (!it.exists()) { it.mkdir() }
}
companion object {
private val TAG = MainActivity::class.java.simpleName
private const val DIRECTORY_MODELS = "models"
private const val FILE_EXTENSION_GGUF = ".gguf"
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
private const val BENCH_SEQUENCE = 1
private const val BENCH_REPETITION = 3
}
}
fun GgufMetadata.filename() = when {
basic.name != null -> {
basic.name?.let { name ->
basic.sizeLabel?.let { size ->
"$name-$size"
} ?: name
@Composable
fun MainCompose(
viewModel: MainViewModel,
clipboard: ClipboardManager,
dm: DownloadManager,
models: List<Downloadable>
) {
Column {
val scrollState = rememberLazyListState()
Box(modifier = Modifier.weight(1f)) {
LazyColumn(state = scrollState) {
items(viewModel.messages) {
Text(
it,
style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
modifier = Modifier.padding(16.dp)
)
}
}
}
}
architecture?.architecture != null -> {
architecture?.architecture?.let { arch ->
basic.uuid?.let { uuid ->
"$arch-$uuid"
} ?: "$arch-${System.currentTimeMillis()}"
OutlinedTextField(
value = viewModel.message,
onValueChange = { viewModel.updateMessage(it) },
label = { Text("Message") },
)
Row {
Button({ viewModel.send() }) { Text("Send") }
Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
Button({ viewModel.clear() }) { Text("Clear") }
Button({
viewModel.messages.joinToString("\n").let {
clipboard.setPrimaryClip(ClipData.newPlainText("", it))
}
}) { Text("Copy") }
}
Column {
for (model in models) {
Downloadable.Button(viewModel, dm, model)
}
}
}
else -> {
"model-${System.currentTimeMillis().toHexString()}"
}
}

View File

@ -0,0 +1,105 @@
package com.example.llama
import android.llama.cpp.LLamaAndroid
import android.util.Log
import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch
class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
}
private val tag: String? = this::class.simpleName
var messages by mutableStateOf(listOf("Initializing..."))
private set
var message by mutableStateOf("")
private set
override fun onCleared() {
super.onCleared()
viewModelScope.launch {
try {
llamaAndroid.unload()
} catch (exc: IllegalStateException) {
messages += exc.message!!
}
}
}
fun send() {
val text = message
message = ""
// Add to messages console.
messages += text
messages += ""
viewModelScope.launch {
llamaAndroid.send(text)
.catch {
Log.e(tag, "send() failed", it)
messages += it.message!!
}
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
}
}
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
viewModelScope.launch {
try {
val start = System.nanoTime()
val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
val warmup = (end - start).toDouble() / NanosPerSecond
messages += "Warm up time: $warmup seconds, please wait..."
if (warmup > 5.0) {
messages += "Warm up took too long, aborting benchmark"
return@launch
}
messages += llamaAndroid.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!!
}
}
}
fun load(pathToModel: String) {
viewModelScope.launch {
try {
llamaAndroid.load(pathToModel)
messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
messages += exc.message!!
}
}
}
fun updateMessage(newMessage: String) {
message = newMessage
}
fun clear() {
messages = listOf()
}
fun log(message: String) {
messages += message
}
}

View File

@ -1,51 +0,0 @@
package com.example.llama
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import android.widget.TextView
import androidx.recyclerview.widget.RecyclerView
data class Message(
val id: String,
val content: String,
val isUser: Boolean
)
class MessageAdapter(
private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
companion object {
private const val VIEW_TYPE_USER = 1
private const val VIEW_TYPE_ASSISTANT = 2
}
override fun getItemViewType(position: Int): Int {
return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
val layoutInflater = LayoutInflater.from(parent.context)
return if (viewType == VIEW_TYPE_USER) {
val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
UserMessageViewHolder(view)
} else {
val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
AssistantMessageViewHolder(view)
}
}
override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
val message = messages[position]
if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
textView.text = message.content
}
}
override fun getItemCount(): Int = messages.size
class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
}

View File

@ -0,0 +1,11 @@
package com.example.llama.ui.theme
import androidx.compose.ui.graphics.Color
val Purple80 = Color(0xFFD0BCFF)
val PurpleGrey80 = Color(0xFFCCC2DC)
val Pink80 = Color(0xFFEFB8C8)
val Purple40 = Color(0xFF6650a4)
val PurpleGrey40 = Color(0xFF625b71)
val Pink40 = Color(0xFF7D5260)

View File

@ -0,0 +1,70 @@
package com.example.llama.ui.theme
import android.app.Activity
import android.os.Build
import androidx.compose.foundation.isSystemInDarkTheme
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.darkColorScheme
import androidx.compose.material3.dynamicDarkColorScheme
import androidx.compose.material3.dynamicLightColorScheme
import androidx.compose.material3.lightColorScheme
import androidx.compose.runtime.Composable
import androidx.compose.runtime.SideEffect
import androidx.compose.ui.graphics.toArgb
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.LocalView
import androidx.core.view.WindowCompat
private val DarkColorScheme = darkColorScheme(
primary = Purple80,
secondary = PurpleGrey80,
tertiary = Pink80
)
private val LightColorScheme = lightColorScheme(
primary = Purple40,
secondary = PurpleGrey40,
tertiary = Pink40
/* Other default colors to override
background = Color(0xFFFFFBFE),
surface = Color(0xFFFFFBFE),
onPrimary = Color.White,
onSecondary = Color.White,
onTertiary = Color.White,
onBackground = Color(0xFF1C1B1F),
onSurface = Color(0xFF1C1B1F),
*/
)
@Composable
fun LlamaAndroidTheme(
darkTheme: Boolean = isSystemInDarkTheme(),
// Dynamic color is available on Android 12+
dynamicColor: Boolean = true,
content: @Composable () -> Unit
) {
val colorScheme = when {
dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
val context = LocalContext.current
if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
}
darkTheme -> DarkColorScheme
else -> LightColorScheme
}
val view = LocalView.current
if (!view.isInEditMode) {
SideEffect {
val window = (view.context as Activity).window
window.statusBarColor = colorScheme.primary.toArgb()
WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
}
}
MaterialTheme(
colorScheme = colorScheme,
typography = Typography,
content = content
)
}

View File

@ -0,0 +1,34 @@
package com.example.llama.ui.theme
import androidx.compose.material3.Typography
import androidx.compose.ui.text.TextStyle
import androidx.compose.ui.text.font.FontFamily
import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.unit.sp
// Set of Material typography styles to start with
val Typography = Typography(
bodyLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 16.sp,
lineHeight = 24.sp,
letterSpacing = 0.5.sp
)
/* Other default text styles to override
titleLarge = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Normal,
fontSize = 22.sp,
lineHeight = 28.sp,
letterSpacing = 0.sp
),
labelSmall = TextStyle(
fontFamily = FontFamily.Default,
fontWeight = FontWeight.Medium,
fontSize = 11.sp,
lineHeight = 16.sp,
letterSpacing = 0.5.sp
)
*/
)

View File

@ -1,4 +0,0 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#E5E5EA" />
<corners android:radius="16dp" />
</shape>

View File

@ -1,4 +0,0 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#4285F4" />
<corners android:radius="16dp" />
</shape>

View File

@ -1,10 +0,0 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
</vector>

View File

@ -1,11 +0,0 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal"
android:autoMirrored="true">
<path
android:fillColor="@android:color/white"
android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
</vector>

View File

@ -1,76 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/main"
android:layout_height="match_parent"
android:layout_width="match_parent">
<LinearLayout
android:fitsSystemWindows="true"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
tools:context=".MainActivity">
<FrameLayout
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1">
<ScrollView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:fadeScrollbars="false">
<TextView
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_margin="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2"
android:maxLines="100" />
</ScrollView>
</FrameLayout>
<androidx.recyclerview.widget.RecyclerView
android:id="@+id/messages"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="4"
android:padding="16dp"
android:fadeScrollbars="false"
app:reverseLayout="true"
tools:listitem="@layout/item_message_assistant"/>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal">
<EditText
android:id="@+id/user_input"
android:enabled="false"
android:layout_width="0dp"
android:layout_weight="1"
android:layout_height="match_parent"
android:padding="8dp"
style="@style/TextAppearance.MaterialComponents.Body2"
android:hint="Please first pick a GGUF model file to import." />
<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/fab"
android:enabled="true"
style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_margin="8dp"
android:src="@drawable/outline_folder_open_24" />
</LinearLayout>
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

View File

@ -1,15 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="start">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_assistant_message"
android:padding="12dp"
android:textColor="@android:color/black" />
</LinearLayout>

View File

@ -1,15 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="8dp"
android:gravity="end">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_user_message"
android:padding="12dp"
android:textColor="@android:color/white" />
</LinearLayout>

View File

@ -1,3 +1,3 @@
<resources>
<string name="app_name">AI Chat basic sample</string>
<string name="app_name">LlamaAndroid</string>
</resources>

View File

@ -1,10 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
<!-- Customize your light theme here. -->
<!-- <item name="colorPrimary">@color/my_light_primary</item> -->
</style>
<style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
<style name="Theme.LlamaAndroid" parent="android:Theme.Material.Light.NoActionBar" />
</resources>

View File

@ -1,6 +1,6 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
alias(libs.plugins.android.application) apply false
alias(libs.plugins.android.library) apply false
alias(libs.plugins.jetbrains.kotlin.android) apply false
id("com.android.application") version "8.2.0" apply false
id("org.jetbrains.kotlin.android") version "1.9.0" apply false
id("com.android.library") version "8.2.0" apply false
}

View File

@ -21,4 +21,3 @@ kotlin.code.style=official
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
android.native.buildOutput=verbose

View File

@ -1,53 +0,0 @@
[versions]
# Plugins
agp = "8.13.0"
kotlin = "2.2.20"
# AndroidX
activity = "1.11.0"
appcompat = "1.7.1"
core-ktx = "1.17.0"
constraint-layout = "2.2.1"
datastore-preferences = "1.1.7"
# Material
material = "1.13.0"
# Testing
espresso-core = "3.7.0"
androidx-junit = "1.3.0"
junit = "4.13.2"
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
[libraries]
# AndroidX
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
#Material
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
# Testing
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
junit = { group = "junit", name = "junit", version.ref = "junit" }
[bundles]
androidx = [
"androidx-activity",
"androidx-appcompat",
"androidx-constraintlayout",
"androidx-core-ktx",
"androidx-datastore-preferences",
]

View File

@ -1,6 +1,6 @@
#Tue Apr 01 11:15:06 PDT 2025
#Thu Dec 21 14:31:09 AEDT 2023
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -1,78 +0,0 @@
plugins {
alias(libs.plugins.android.library)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.arm.aichat"
compileSdk = 36
ndkVersion = "29.0.13113456"
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
abiFilters += listOf("arm64-v8a", "x86_64")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
arguments += "-DBUILD_SHARED_LIBS=ON"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DGGML_NATIVE=OFF"
arguments += "-DGGML_BACKEND_DL=ON"
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
}
}
aarMetadata {
minCompileSdk = 35
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.31.6"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlin {
jvmToolchain(17)
compileOptions {
targetCompatibility = JavaVersion.VERSION_17
}
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
publishing {
singleVariant("release") {
withJavadocJar()
}
}
}
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
}

View File

@ -1,8 +0,0 @@
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-keepclasseswithmembernames class * {
native <methods>;
}
-keep class kotlin.Metadata { *; }

View File

@ -1,56 +0,0 @@
cmake_minimum_required(VERSION 3.31.6)
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
# --------------------------------------------------------------------------
# AI Chat library
# --------------------------------------------------------------------------
if(DEFINED ANDROID_ABI)
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(GGML_SYSTEM_ARCH "ARM")
set(GGML_CPU_KLEIDIAI ON)
set(GGML_OPENMP ON)
elseif(ANDROID_ABI STREQUAL "x86_64")
set(GGML_SYSTEM_ARCH "x86")
set(GGML_CPU_KLEIDIAI OFF)
set(GGML_OPENMP OFF)
else()
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
endif()
endif()
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama)
add_library(${CMAKE_PROJECT_NAME} SHARED
ai_chat.cpp)
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
${LLAMA_SRC}
${LLAMA_SRC}/common
${LLAMA_SRC}/include
${LLAMA_SRC}/ggml/include
${LLAMA_SRC}/ggml/src)
target_link_libraries(${CMAKE_PROJECT_NAME}
llama
common
android
log)

View File

@ -1,565 +0,0 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <cmath>
#include <string>
#include <unistd.h>
#include <sampling.h>
#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
std::ostringstream str;
for (size_t i = 0; i < values.size(); i++) {
str << values[i];
if (i < values.size() - 1) { str << delim; }
}
return str.str();
}
/**
* LLama resources: context, model, batch and sampler
*/
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
// Set llama log handler to Android
llama_log_set(aichat_android_log_callback, nullptr);
// Loading all CPU backend variants
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
LOGi("Loading backends from %s", path_to_backend);
ggml_backend_load_all_from_path(path_to_backend);
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated; Log handler set.");
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
llama_model_params model_params = llama_model_default_params();
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
auto *model = llama_model_load_from_file(model_path, model_params);
env->ReleaseStringUTFChars(jmodel_path, model_path);
if (!model) {
return 1;
}
g_model = model;
return 0;
}
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
if (!model) {
LOGe("%s: model cannot be null", __func__);
return nullptr;
}
// Multi-threading setup
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
(int) sysconf(_SC_NPROCESSORS_ONLN) -
N_THREADS_HEADROOM));
LOGi("%s: Using %d threads", __func__, n_threads);
// Context parameters setup
llama_context_params ctx_params = llama_context_default_params();
const int trained_context_size = llama_model_n_ctx_train(model);
if (n_ctx > trained_context_size) {
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
__func__, trained_context_size, n_ctx);
}
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = BATCH_SIZE;
ctx_params.n_ubatch = BATCH_SIZE;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
auto *context = llama_init_from_model(g_model, ctx_params);
if (context == nullptr) {
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
}
return context;
}
static common_sampler *new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
return common_sampler_init(g_model, sparams);
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return 1; }
g_context = context;
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
g_chat_templates = common_chat_templates_init(g_model, "");
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
return 0;
}
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto *reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
return backends.empty() ? "CPU" : join(backends, ",");
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
jint pl, jint nr) {
auto *context = init_context(g_model, pp);
if (!context) {
const auto *const err_msg = "Fail to init_context! Bench aborted.";
LOGe(err_msg);
return env->NewStringUTF(err_msg);
}
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const uint32_t n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp = %d)", pp);
common_batch_clear(g_batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(g_batch, 0, i, {0}, false);
}
g_batch.logits[g_batch.n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg = %d)", tg);
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(g_batch);
for (j = 0; j < pl; j++) {
common_batch_add(g_batch, 0, i, {j}, true);
}
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
llama_free(context);
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(g_model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
const auto backend = get_backend();
std::stringstream result;
result << std::setprecision(3);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
/**
* Completion loop's long-term states:
* - chat management
* - position tracking
*/
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;
static void reset_long_term_states(const bool clear_kv_cache = true) {
chat_msgs.clear();
system_prompt_position = 0;
current_position = 0;
if (clear_kv_cache)
llama_memory_clear(llama_get_memory(g_context), false);
}
/**
* TODO-hyin: implement sliding-window version as a better alternative
*
* Context shifting by discarding the older half of the tokens appended after system prompt:
* - take the [system_prompt_position] first tokens from the original prompt
* - take half of the last (system_prompt_position - system_prompt_position) tokens
* - recompute the logits in batches
*/
static void shift_context() {
const int n_discard = (current_position - system_prompt_position) / 2;
LOGi("%s: Discarding %d tokens", __func__, n_discard);
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
current_position -= n_discard;
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
chat_msgs.push_back(new_msg);
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
return formatted;
}
/**
* Completion loop's short-term states:
* - stop generation position
* - token chars caching
* - current assistant message being generated
*/
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;
static void reset_short_term_states() {
stop_generation_position = 0;
cached_token_chars.clear();
assistant_ss.str("");
}
static int decode_tokens_in_batches(
llama_context *context,
llama_batch &batch,
const llama_tokens &tokens,
const llama_pos start_pos,
const bool compute_last_logit = false) {
// Process tokens in batches using the global batch
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(batch);
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
// Shift context if current batch cannot fit into the context
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
shift_context();
}
// Add tokens to the batch with proper positions
for (int j = 0; j < cur_batch_size; j++) {
const llama_token token_id = tokens[i + j];
const llama_pos position = start_pos + i + j;
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
common_batch_add(batch, token_id, position, {0}, want_logit);
}
// Decode this batch
const int decode_result = llama_decode(context, batch);
if (decode_result) {
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
return 1;
}
}
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring jsystem_prompt
) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
std::string formatted_system_prompt(system_prompt);
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
// Format system prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
has_chat_template, has_chat_template);
for (auto id: system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Handle context overflow
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) system_tokens.size() > max_batch_size) {
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
__func__, (int) system_tokens.size(), max_batch_size);
return 1;
}
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
system_prompt_position = current_position = (int) system_tokens.size();
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring juser_prompt,
jint n_predict
) {
// Reset short-term states
reset_short_term_states();
// Obtain and tokenize user prompt
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
std::string formatted_user_prompt(user_prompt);
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
// Format user prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
}
// Decode formatted user prompts
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id: user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int user_prompt_size = (int) user_tokens.size();
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if (user_prompt_size > max_batch_size) {
const int skipped_tokens = user_prompt_size - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
}
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
current_position += user_prompt_size;
stop_generation_position = current_position + user_prompt_size + n_predict;
return 0;
}
static bool is_valid_utf8(const char *string) {
if (!string) { return true; }
const auto *bytes = (const unsigned char *) string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
JNIEnv *env,
jobject /*unused*/
) {
// Infinite text generation via context shifting
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Context full! Shifting...", __func__);
shift_context();
}
// Stop if reaching the marked position
if (current_position >= stop_generation_position) {
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
return nullptr;
}
// Sample next token
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
common_sampler_accept(g_sampler, new_token_id, true);
// Populate the batch with new token, then decode
common_batch_clear(g_batch);
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
if (llama_decode(g_context, g_batch) != 0) {
LOGe("%s: llama_decode() failed for generated token", __func__);
return nullptr;
}
// Update position
current_position++;
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
return nullptr;
}
// If not EOG, convert to text
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
cached_token_chars += new_token_chars;
// Create and return a valid UTF-8 Java string
jstring result = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
result = env->NewStringUTF(cached_token_chars.c_str());
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
assistant_ss << cached_token_chars;
cached_token_chars.clear();
} else {
LOGv("id: %d,\tappend to cache", new_token_id);
result = env->NewStringUTF("");
}
return result;
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Free up resources
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *env, jobject /*unused*/) {
llama_backend_free();
}

View File

@ -1,61 +0,0 @@
//
// Created by Han Yin on 10/31/25.
//
#ifndef AICHAT_LOGGING_H
#define AICHAT_LOGGING_H
#endif //AICHAT_LOGGING_H
#pragma once
#include <android/log.h>
#ifndef LOG_TAG
#define LOG_TAG "ai-chat"
#endif
#ifndef LOG_MIN_LEVEL
#if defined(NDEBUG)
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
#else
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
#endif
#endif
static inline int ai_should_log(int prio) {
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
}
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGv(...) ((void)0)
#endif
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGd(...) ((void)0)
#endif
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
default: return ANDROID_LOG_DEFAULT;
}
}
static inline void aichat_android_log_callback(enum ggml_log_level level,
const char* text,
void* /*user*/) {
const int prio = android_log_prio_from_ggml(level);
if (!ai_should_log(prio)) return;
__android_log_write(prio, LOG_TAG, text);
}

View File

@ -1,14 +0,0 @@
package com.arm.aichat
import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl
/**
* Main entry point for Arm's AI Chat library.
*/
object AiChat {
/**
* Get the inference engine single instance.
*/
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}

View File

@ -1,89 +0,0 @@
package com.arm.aichat
import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
* Interface defining the core LLM inference operations.
*/
interface InferenceEngine {
/**
* Current state of the inference engine
*/
val state: StateFlow<State>
/**
* Load a model from the given path.
*
* @throws UnsupportedArchitectureException if model architecture not supported
*/
suspend fun loadModel(pathToModel: String)
/**
* Sends a system prompt to the loaded model
*/
suspend fun setSystemPrompt(systemPrompt: String)
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
/**
* Runs a benchmark with the specified parameters.
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
/**
* Unloads the currently loaded model.
*/
suspend fun cleanUp()
/**
* Cleans up resources when the engine is no longer needed.
*/
fun destroy()
/**
* States of the inference engine
*/
sealed class State {
object Uninitialized : State()
object Initializing : State()
object Initialized : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object Benchmarking : State()
object ProcessingSystemPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
data class Error(val exception: Exception) : State()
}
companion object {
const val DEFAULT_PREDICT_LENGTH = 1024
}
}
val State.isUninterruptible
get() = this is State.Initializing ||
this is State.LoadingModel ||
this is State.UnloadingModel ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt
val State.isModelLoaded: Boolean
get() = this is State.ModelReady ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt ||
this is State.Generating
class UnsupportedArchitectureException : Exception()

View File

@ -1,61 +0,0 @@
package com.arm.aichat.gguf
import kotlin.collections.get
/**
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
* The `label` matches what llamacli prints.
*/
enum class FileType(val code: Int, val label: String) {
ALL_F32(0, "all F32"),
MOSTLY_F16(1, "F16"),
MOSTLY_Q4_0(2, "Q4_0"),
MOSTLY_Q4_1(3, "Q4_1"),
// 4 removed
MOSTLY_Q8_0(7, "Q8_0"),
MOSTLY_Q5_0(8, "Q5_0"),
MOSTLY_Q5_1(9, "Q5_1"),
/* Kquants ------------------------------------------------------------ */
MOSTLY_Q2_K (10, "Q2_K - Medium"),
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
MOSTLY_Q6_K (18, "Q6_K"),
/* IQ quants ----------------------------------------------------------- */
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
/* BF16 & Ternary ------------------------------------------------------ */
MOSTLY_BF16 (32, "BF16"),
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
/* Special flag -------------------------------------------------------- */
GUESSED(1024, "(guessed)"),
UNKNOWN(-1, "unknown");
companion object {
private val map = entries.associateBy(FileType::code)
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
}
}

View File

@ -1,132 +0,0 @@
package com.arm.aichat.gguf
import java.io.IOException
/**
* Structured metadata of GGUF
*/
data class GgufMetadata(
// Basic file info
val version: GgufVersion,
val tensorCount: Long,
val kvCount: Long,
// General info
val basic: BasicInfo,
val author: AuthorInfo? = null,
val additional: AdditionalInfo? = null,
val architecture: ArchitectureInfo? = null,
val baseModels: List<BaseModelInfo>? = null,
val tokenizer: TokenizerInfo? = null,
// Derivative info
val dimensions: DimensionsInfo? = null,
val attention: AttentionInfo? = null,
val rope: RopeInfo? = null,
val experts: ExpertsInfo? = null
) {
enum class GgufVersion(val code: Int, val label: String) {
/** First public draft; littleendian only, no alignment key. */
LEGACY_V1(1, "Legacy v1"),
/** Added splitfile support and some extra metadata keys. */
EXTENDED_V2(2, "Extended v2"),
/** Current spec: endianaware, mandatory alignment, fully validated. */
VALIDATED_V3(3, "Validated v3");
companion object {
fun fromCode(code: Int): GgufVersion =
entries.firstOrNull { it.code == code }
?: throw IOException("Unknown GGUF version code $code")
}
override fun toString(): String = "$label (code=$code)"
}
data class BasicInfo(
val uuid: String? = null,
val name: String? = null,
val nameLabel: String? = null,
val sizeLabel: String? = null, // Size label like "7B"
)
data class AuthorInfo(
val organization: String? = null,
val author: String? = null,
val doi: String? = null,
val url: String? = null,
val repoUrl: String? = null,
val license: String? = null,
val licenseLink: String? = null,
)
data class AdditionalInfo(
val type: String? = null,
val description: String? = null,
val tags: List<String>? = null,
val languages: List<String>? = null,
)
data class ArchitectureInfo(
val architecture: String? = null,
val fileType: Int? = null,
val vocabSize: Int? = null,
val finetune: String? = null,
val quantizationVersion: Int? = null,
)
data class BaseModelInfo(
val name: String? = null,
val author: String? = null,
val version: String? = null,
val organization: String? = null,
val url: String? = null,
val doi: String? = null,
val uuid: String? = null,
val repoUrl: String? = null,
)
data class TokenizerInfo(
val model: String? = null,
val bosTokenId: Int? = null,
val eosTokenId: Int? = null,
val unknownTokenId: Int? = null,
val paddingTokenId: Int? = null,
val addBosToken: Boolean? = null,
val addEosToken: Boolean? = null,
val chatTemplate: String? = null,
)
data class DimensionsInfo(
val contextLength: Int? = null,
val embeddingSize: Int? = null,
val blockCount: Int? = null,
val feedForwardSize: Int? = null,
)
data class AttentionInfo(
val headCount: Int? = null,
val headCountKv: Int? = null,
val keyLength: Int? = null,
val valueLength: Int? = null,
val layerNormEpsilon: Float? = null,
val layerNormRmsEpsilon: Float? = null,
)
data class RopeInfo(
val frequencyBase: Float? = null,
val dimensionCount: Int? = null,
val scalingType: String? = null,
val scalingFactor: Float? = null,
val attnFactor: Float? = null,
val originalContextLength: Int? = null,
val finetuned: Boolean? = null,
)
data class ExpertsInfo(
val count: Int? = null,
val usedCount: Int? = null,
)
}

View File

@ -1,77 +0,0 @@
package com.arm.aichat.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Interface for reading GGUF metadata from model files.
* Use `GgufMetadataReader.create()` to get an instance.
*/
interface GgufMetadataReader {
/**
* Reads the magic number from the specified file path.
*
* @param file Java File to the GGUF file with absolute path
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(file: File): Boolean
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining [android.content.ContentProvider]
* @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
* @return true if file is valid GGUF, otherwise false
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
/**
* Reads and parses GGUF metadata from the specified file path.
*
* @param input the [InputStream] obtained from a readable file or content
* @return Structured metadata extracted from the file
* @throws IOException if file is damaged or cannot be read
* @throws InvalidFileFormatException if file format is invalid
*/
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
companion object {
private val DEFAULT_SKIP_KEYS = setOf(
"tokenizer.chat_template",
"tokenizer.ggml.scores",
"tokenizer.ggml.tokens",
"tokenizer.ggml.token_type"
)
/**
* Creates a default GgufMetadataReader instance
*/
fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = DEFAULT_SKIP_KEYS,
arraySummariseThreshold = 1_000
)
/**
* Creates a GgufMetadataReader with custom configuration
*
* @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
* @param arraySummariseThreshold If 0, arrays longer get summarised, not materialised;
* If -1, never summarise.
*/
fun create(
skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
arraySummariseThreshold: Int = 1_000
): GgufMetadataReader = GgufMetadataReaderImpl(
skipKeys = skipKeys,
arraySummariseThreshold = arraySummariseThreshold
)
}
}
class InvalidFileFormatException : IOException()

View File

@ -1,309 +0,0 @@
package com.arm.aichat.internal
import android.content.Context
import android.util.Log
import com.arm.aichat.InferenceEngine
import com.arm.aichat.UnsupportedArchitectureException
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
import dalvik.annotation.optimization.FastNative
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException
/**
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
*
* This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
* All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
* with the underlying C++ native code.
*
* The typical usage flow is:
* 1. Get instance via [getInstance]
* 2. Load a model with [loadModel]
* 3. Send prompts with [sendUserPrompt]
* 4. Generate responses as token streams
* 5. Perform [cleanUp] when done with a model
* 6. Properly [destroy] when completely done
*
* State transitions are managed automatically and validated at each operation.
*
* @see ai_chat.cpp for the native implementation details
*/
internal class InferenceEngineImpl private constructor(
private val nativeLibDir: String
) : InferenceEngine {
companion object {
private val TAG = InferenceEngineImpl::class.java.simpleName
@Volatile
private var instance: InferenceEngine? = null
/**
* Create or obtain [InferenceEngineImpl]'s single instance.
*
* @param Context for obtaining native library directory
* @throws IllegalArgumentException if native library path is invalid
* @throws UnsatisfiedLinkError if library failed to load
*/
internal fun getInstance(context: Context) =
instance ?: synchronized(this) {
val nativeLibDir = context.applicationInfo.nativeLibraryDir
require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
try {
Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
InferenceEngineImpl(nativeLibDir).also { instance = it }
} catch (e: UnsatisfiedLinkError) {
Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
throw e
}
}
}
/**
* JNI methods
* @see ai_chat.cpp
*/
@FastNative
private external fun init(nativeLibDir: String)
@FastNative
private external fun load(modelPath: String): Int
@FastNative
private external fun prepare(): Int
@FastNative
private external fun systemInfo(): String
@FastNative
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
@FastNative
private external fun processSystemPrompt(systemPrompt: String): Int
@FastNative
private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
@FastNative
private external fun generateNextToken(): String?
@FastNative
private external fun unload()
@FastNative
private external fun shutdown()
private val _state =
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
override val state: StateFlow<InferenceEngine.State> = _state
private var _readyForSystemPrompt = false
/**
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
*/
@OptIn(ExperimentalCoroutinesApi::class)
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
init {
llamaScope.launch {
try {
check(_state.value is InferenceEngine.State.Uninitialized) {
"Cannot load native library in ${_state.value.javaClass.simpleName}!"
}
_state.value = InferenceEngine.State.Initializing
Log.i(TAG, "Loading native library...")
System.loadLibrary("ai-chat")
init(nativeLibDir)
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
} catch (e: Exception) {
Log.e(TAG, "Failed to load native library", e)
throw e
}
}
}
/**
* Load the LLM
*/
override suspend fun loadModel(pathToModel: String) =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.Initialized) {
"Cannot load model in ${_state.value.javaClass.simpleName}!"
}
try {
Log.i(TAG, "Checking access to model file... \n$pathToModel")
File(pathToModel).let {
require(it.exists()) { "File not found" }
require(it.isFile) { "Not a valid file" }
require(it.canRead()) { "Cannot read file" }
}
Log.i(TAG, "Loading model... \n$pathToModel")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.LoadingModel
load(pathToModel).let {
// TODO-han.yin: find a better way to pass other error codes
if (it != 0) throw UnsupportedArchitectureException()
}
prepare().let {
if (it != 0) throw IOException("Failed to prepare resources")
}
Log.i(TAG, "Model loaded!")
_readyForSystemPrompt = true
_state.value = InferenceEngine.State.ModelReady
} catch (e: Exception) {
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}
/**
* Process the plain text system prompt
*
* TODO-han.yin: return error code if system prompt not correct processed?
*/
override suspend fun setSystemPrompt(prompt: String) =
withContext(llamaDispatcher) {
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
}
Log.i(TAG, "Sending system prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingSystemPrompt
processSystemPrompt(prompt).let { result ->
if (result != 0) {
RuntimeException("Failed to process system prompt: $result").also {
_state.value = InferenceEngine.State.Error(it)
throw it
}
}
}
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
}
/**
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
*/
override fun sendUserPrompt(
message: String,
predictLength: Int,
): Flow<String> = flow {
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
check(_state.value is InferenceEngine.State.ModelReady) {
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
}
try {
Log.i(TAG, "Sending user prompt...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.ProcessingUserPrompt
processUserPrompt(message, predictLength).let { result ->
if (result != 0) {
Log.e(TAG, "Failed to process user prompt: $result")
return@flow
}
}
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
_state.value = InferenceEngine.State.Generating
while (true) {
generateNextToken()?.let { utf8token ->
if (utf8token.isNotEmpty()) emit(utf8token)
} ?: break
}
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
_state.value = InferenceEngine.State.ModelReady
} catch (e: CancellationException) {
Log.i(TAG, "Generation cancelled by user.")
_state.value = InferenceEngine.State.ModelReady
throw e
} catch (e: Exception) {
Log.e(TAG, "Error during generation!", e)
_state.value = InferenceEngine.State.Error(e)
throw e
}
}.flowOn(llamaDispatcher)
/**
* Benchmark the model
*/
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
withContext(llamaDispatcher) {
check(_state.value is InferenceEngine.State.ModelReady) {
"Benchmark request discarded due to: $state"
}
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
_readyForSystemPrompt = false // Just to be safe
_state.value = InferenceEngine.State.Benchmarking
benchModel(pp, tg, pl, nr).also {
_state.value = InferenceEngine.State.ModelReady
}
}
/**
* Unloads the model and frees resources, or reset error states
*/
override suspend fun cleanUp() =
withContext(llamaDispatcher) {
when (val state = _state.value) {
is InferenceEngine.State.ModelReady -> {
Log.i(TAG, "Unloading model and free resources...")
_readyForSystemPrompt = false
_state.value = InferenceEngine.State.UnloadingModel
unload()
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "Model unloaded!")
Unit
}
is InferenceEngine.State.Error -> {
Log.i(TAG, "Resetting error states...")
_state.value = InferenceEngine.State.Initialized
Log.i(TAG, "States reset!")
Unit
}
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
}
}
/**
* Cancel all ongoing coroutines and free GGML backends
*/
override fun destroy() {
_readyForSystemPrompt = false
llamaScope.cancel()
when(_state.value) {
is InferenceEngine.State.Uninitialized -> {}
is InferenceEngine.State.Initialized -> shutdown()
else -> { unload(); shutdown() }
}
}
}

View File

@ -1,590 +0,0 @@
package com.arm.aichat.internal.gguf
import android.content.Context
import android.net.Uri
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.arm.aichat.gguf.InvalidFileFormatException
import java.io.File
import java.io.IOException
import java.io.InputStream
/**
* Utility class to read GGUF model files and extract metadata key-value pairs.
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
*/
internal class GgufMetadataReaderImpl(
private val skipKeys: Set<String>,
private val arraySummariseThreshold: Int,
) : GgufMetadataReader {
companion object {
private const val ARCH_LLAMA = "llama"
}
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
enum class MetadataType(val code: Int) {
UINT8(0), INT8(1), UINT16(2), INT16(3),
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
companion object {
private val codeMap = entries.associateBy(MetadataType::code)
fun fromCode(code: Int): MetadataType = codeMap[code]
?: throw IOException("Unknown metadata value type code: $code")
}
}
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
sealed class MetadataValue {
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
}
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
private fun MetadataValue.toPrimitive(): Any = when (this) {
is MetadataValue.UInt8 -> value
is MetadataValue.Int8 -> value
is MetadataValue.UInt16 -> value
is MetadataValue.Int16 -> value
is MetadataValue.UInt32 -> value
is MetadataValue.Int32 -> value
is MetadataValue.Float32 -> value
is MetadataValue.Bool -> value
is MetadataValue.StringVal -> value
is MetadataValue.UInt64 -> value
is MetadataValue.Int64 -> value
is MetadataValue.Float64 -> value
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
}
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(file: File): Boolean =
file.inputStream().buffered().use { ensureMagic(it) }
/**
* Reads the magic number from the specified file path.
*
* @param context Context for obtaining ContentResolver
* @param uri Uri to the GGUF file provided by ContentProvider
* @return true if file is valid GGUF, otherwise false
*/
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
/** Reads the 4byte magic; throws if magic ≠ "GGUF". */
private fun ensureMagic(input: InputStream): Boolean =
ByteArray(4).let {
if (input.read(it) != 4) throw IOException("Not a valid file!")
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
}
/**
* Highlevel entry point: parses a `.gguf` file on disk and returns the fully
* populated [GgufMetadata] tree.
*
* Steps performed internally:
* 1. Reads and validates the 8byte header (`"GGUF"` magic + version).
* 2. Streams through the keyvalue section, skipping large blobs if the key
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
* 3. Converts the resulting raw map into stronglytyped substructures
* (basic info, tokenizer, rope, etc.).
*
* The method is STREAMINGONLY: tensors are never mapped or loaded into
* memory, so even multiGB model files can be processed in < 50 ms.
*
* @param path Absolute or relative filesystem path to a `.gguf` file.
* @return A [GgufMetadata] instance containing all recognised metadata plus
* an `allMetadata` map with any keys that were not given a dedicated
* field.
* @throws IOException if the file is not GGUF, the version is unsupported,
* or the metadata block is truncated / corrupt.
*/
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
// ── 1. header ──────────────────────────────────────────────────────────
// throws on mismatch
val version = ensureMagicAndVersion(input)
val tensorCount = readLittleLong(input)
val kvCount = readLittleLong(input)
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
// ── 3. build structured object ────────────────────────────────────────
return buildStructured(meta, version, tensorCount, kvCount)
}
/** Reads the 4byte magic + 4byte version; throws if magic ≠ "GGUF". */
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
if (!ensureMagic(input)) throw InvalidFileFormatException()
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
}
/**
* Read an unsigned 32bit littleendian integer.
*
* @throws IOException if fewer than four bytes are available.
*/
private fun readLEUInt32(input: InputStream): Int {
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
return (b3 and 0xFF shl 24) or
(b2 and 0xFF shl 16) or
(b1 and 0xFF shl 8) or
(b0 and 0xFF)
}
/**
* Lowlevel helper that reads the entire key-value section from the current
* stream position.
*
* @param input Open stream positioned JUST AFTER the header.
* @param kvCnt Number of keyvalue pairs (taken from the header).
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
*
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
* [skipValue] or [parseValue] accordingly.
*/
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
mutableMapOf<String, MetadataValue>().apply {
repeat(kvCnt.toInt()) {
val key = readString(input)
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
if (key in skipKeys) {
skipValue(input, valueT)
} else {
this[key] = parseValue(input, valueT)
}
}
}
/**
* Converts a flat [Map]<[String], [MetadataValue]> into the stronglytyped
* [GgufMetadata] tree used by the rest of the app.
*
* Only the keys listed in the spec are copied into dedicated data classes;
* everything else is preserved in `GgufMetadata.allMetadata`.
*
* @param m Raw key/value map.
* @param version GGUF fileformat version (enum).
* @param tensorCnt Number of tensors (from the header).
* @param kvCnt Total metadata pair count (from the header).
*/
private fun buildStructured(
m: Map<String, MetadataValue>,
version: GgufMetadata.GgufVersion,
tensorCnt: Long,
kvCnt: Long
): GgufMetadata {
// ---------- helpers ----------
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
fun String.strList(): List<String>? =
(m[this] as? MetadataValue.ArrayVal)
?.elements
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
val arch = "general.architecture".str() ?: ARCH_LLAMA
// -------------- populate sections ----------------
val basic = GgufMetadata.BasicInfo(
uuid = "general.uuid".str(),
name = "general.basename".str(),
nameLabel = "general.name".str(),
sizeLabel = "general.size_label".str()
)
val author = GgufMetadata.AuthorInfo(
organization = "general.organization".str(),
author = "general.author".str(),
doi = "general.doi".str(),
url = "general.url".str(),
repoUrl = "general.repo_url".str(),
license = "general.license".str(),
licenseLink = "general.license.link".str()
).takeUnless {
organization == null && author == null && doi == null &&
url == null && repoUrl == null && license == null && licenseLink == null
}
val additional = GgufMetadata.AdditionalInfo(
type = "general.type".str(),
description = "general.description".str(),
tags = "general.tags".strList(),
languages = "general.languages".strList()
).takeUnless {
type == null && description == null && tags == null && languages == null
}
val architectureInfo = GgufMetadata.ArchitectureInfo(
architecture = arch,
fileType = "general.file_type".u32(),
vocabSize = "$arch.vocab_size".u32(),
finetune = "general.finetune".str(),
quantizationVersion = "general.quantization_version".u32()
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
val baseModels = buildList {
val n = "general.base_model.count".u32() ?: 0
for (i in 0 until n) {
fun k(s: String) = "general.base_model.$i.$s"
add(
GgufMetadata.BaseModelInfo(
name = k("name").str(),
author = k("author").str(),
version = k("version").str(),
organization = k("organization").str(),
url = k("url").str(),
doi = k("doi").str(),
uuid = k("uuid").str(),
repoUrl = k("repo_url").str(),
)
)
}
}.takeIf { it.isNotEmpty() }
val tokenizer = GgufMetadata.TokenizerInfo(
model = "tokenizer.ggml.model".str(),
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
chatTemplate = "tokenizer.chat_template".str()
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
unknownTokenId == null && paddingTokenId == null &&
addBosToken == null && addEosToken == null && chatTemplate == null
}
val dimensions = GgufMetadata.DimensionsInfo(
contextLength = "$arch.context_length".u32(),
embeddingSize = "$arch.embedding_length".u32(),
blockCount = "$arch.block_count".u32(),
feedForwardSize = "$arch.feed_forward_length".u32()
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
val attention = GgufMetadata.AttentionInfo(
headCount = "$arch.attention.head_count".u32(),
headCountKv = "$arch.attention.head_count_kv".u32(),
keyLength = "$arch.attention.key_length".u32(),
valueLength = "$arch.attention.value_length".u32(),
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
layerNormEpsilon == null && layerNormRmsEpsilon == null
}
val rope = GgufMetadata.RopeInfo(
frequencyBase = "$arch.rope.freq_base".f32(),
dimensionCount = "$arch.rope.dimension_count".u32(),
scalingType = "$arch.rope.scaling.type".str(),
scalingFactor = "$arch.rope.scaling.factor".f32(),
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
finetuned = "$arch.rope.scaling.finetuned".bool()
).takeUnless { frequencyBase == null && dimensionCount == null &&
scalingType == null && scalingFactor == null && attnFactor == null &&
originalContextLength == null && finetuned == null
}
val experts = GgufMetadata.ExpertsInfo(
count = "$arch.expert_count".u32(),
usedCount = "$arch.expert_used_count".u32()
).takeUnless { count == null && usedCount == null }
return GgufMetadata(
version = version,
tensorCount = tensorCnt,
kvCount = kvCnt,
basic = basic,
author = author,
additional = additional,
architecture = architectureInfo,
baseModels = baseModels,
tokenizer = tokenizer,
dimensions = dimensions,
attention = attention,
rope = rope,
experts = experts
)
}
/**
* Recursively parses a metadata value of the given type from the input stream.
* @param input The input stream positioned at the start of the value.
* @param type The metadata value type to parse.
*/
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
MetadataType.UINT8 -> {
// 1-byte unsigned integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
MetadataValue.UInt8(byteVal.toUByte())
}
MetadataType.INT8 -> {
// 1-byte signed integer
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
MetadataValue.Int8(byteVal.toByte())
}
MetadataType.UINT16 -> {
// 2-byte unsigned integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
// Combine two bytes (little-endian) into an unsigned 16-bit value
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.UInt16(u16.toUShort())
}
MetadataType.INT16 -> {
// 2-byte signed integer (little-endian)
val bytes = ByteArray(2)
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
// Combine to 16-bit and interpret as signed
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
MetadataValue.Int16(i16.toShort())
}
MetadataType.UINT32 -> {
// 4-byte unsigned integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.UInt32(u32.toUInt())
}
MetadataType.INT32 -> {
// 4-byte signed integer (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
// Combine four bytes into a 32-bit signed int
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
MetadataValue.Int32(i32)
}
MetadataType.FLOAT32 -> {
// 4-byte IEEE 754 float (little-endian)
val bytes = ByteArray(4)
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
val bits = (bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
val floatVal = Float.fromBits(bits)
MetadataValue.Float32(floatVal)
}
MetadataType.BOOL -> {
// 1-byte boolean (0 = false, 1 = true)
val byteVal = input.read()
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
if (byteVal != 0 && byteVal != 1) {
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
}
MetadataValue.Bool(byteVal != 0)
}
MetadataType.STRING -> {
// UTF-8 string (length-prefixed with 8-byte length)
val str = readString(input)
MetadataValue.StringVal(str)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
val count = len.toInt()
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
// fastforward without allocation
repeat(count) { skipValue(input, elemType) }
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
} else {
val list = ArrayList<MetadataValue>(count)
repeat(count) { list += parseValue(input, elemType) }
MetadataValue.ArrayVal(elemType, list)
}
}
MetadataType.UINT64 -> {
// 8-byte unsigned integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
(bytes[6].toULong() and 0xFFuL shl 48) or
(bytes[5].toULong() and 0xFFuL shl 40) or
(bytes[4].toULong() and 0xFFuL shl 32) or
(bytes[3].toULong() and 0xFFuL shl 24) or
(bytes[2].toULong() and 0xFFuL shl 16) or
(bytes[1].toULong() and 0xFFuL shl 8) or
(bytes[0].toULong() and 0xFFuL)
MetadataValue.UInt64(u64)
}
MetadataType.INT64 -> {
// 8-byte signed integer (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
// Combine 8 bytes into a signed 64-bit value (Long)
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
MetadataValue.Int64(i64)
}
MetadataType.FLOAT64 -> {
// 8-byte IEEE 754 double (little-endian)
val bytes = ByteArray(8)
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
val doubleVal = Double.fromBits(bits)
MetadataValue.Float64(doubleVal)
}
}
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
this?.takeIf { !it.check() }
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
private fun skipValue(input: InputStream, type: MetadataType) {
when (type) {
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
MetadataType.STRING -> {
val len = readLittleLong(input); input.skipFully(len)
}
MetadataType.ARRAY -> {
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
val len = readLittleLong(input)
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
}
}
}
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
private fun readLittleLong(input: InputStream): Long {
val bytes = ByteArray(8)
input.readFully(bytes)
// Combine 8 bytes into a 64-bit value (Little Endian).
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
// In our context (lengths/counts), such extremely large values are not expected.
return (bytes[7].toLong() and 0xFFL shl 56) or
(bytes[6].toLong() and 0xFFL shl 48) or
(bytes[5].toLong() and 0xFFL shl 40) or
(bytes[4].toLong() and 0xFFL shl 32) or
(bytes[3].toLong() and 0xFFL shl 24) or
(bytes[2].toLong() and 0xFFL shl 16) or
(bytes[1].toLong() and 0xFFL shl 8) or
(bytes[0].toLong() and 0xFFL)
}
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
private fun readString(input: InputStream): String =
// Read 8-byte little-endian length (number of bytes in the string).
readLittleLong(input).let { len ->
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
// Read the UTF-8 bytes of the given length.
ByteArray(len.toInt()).let {
if (it.isNotEmpty()) input.readFully(it)
String(it, Charsets.UTF_8)
}
}
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
// Note: assumes bytes length is 4.
(bytes[3].toInt() and 0xFF shl 24) or
(bytes[2].toInt() and 0xFF shl 16) or
(bytes[1].toInt() and 0xFF shl 8) or
(bytes[0].toInt() and 0xFF)
/**
* Robust skip that works the same on JDK 11 and Androids desugared runtime.
*
* @param n Number of bytes to advance in the stream.
* @throws IOException on premature EOF.
*/
private fun InputStream.skipFully(n: Long) {
var remaining = n
val scratch = ByteArray(8192) // readandtoss buffer
while (remaining > 0) {
val skipped = skip(remaining)
when {
skipped > 0 -> remaining -= skipped // normal fast path
skipped == 0L -> {
// fallback: read and discard
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
if (read == -1) throw IOException("EOF while skipping $n bytes")
remaining -= read
}
else -> throw IOException("Skip returned negative value")
}
}
}
/**
* Extension that keeps reading until the requested number of bytes are filled.
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
* streams.
*
* @param buf Destination buffer.
* @param len Number of bytes to fill (defaults to `buf.size`).
* @throws IOException on premature EOF.
*/
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
var off = 0
while (off < len) {
val n = read(buf, off, len - off)
if (n == -1) throw IOException("EOF after $off of $len bytes")
off += n
}
}
/**
* Read EXACTLY `n` bytes or throw never returns a partiallyfilled array.
* This is used for small fixedlength reads (e.g. 4byte type codes).
*
* @throws IOException on premature EOF.
*/
private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
if (read(it) != n) throw IOException("Unexpected EOF")
}
}

View File

@ -0,0 +1,71 @@
plugins {
id("com.android.library")
id("org.jetbrains.kotlin.android")
}
android {
namespace = "android.llama.cpp"
compileSdk = 34
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
// Add NDK properties if wanted, e.g.
// abiFilters += listOf("arm64-v8a")
}
externalNativeBuild {
cmake {
arguments += "-DLLAMA_CURL=OFF"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
arguments += "-DCMAKE_BUILD_TYPE=Release"
cppFlags += listOf()
arguments += listOf()
cppFlags("")
}
}
}
buildTypes {
release {
isMinifyEnabled = false
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.22.1"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
}
kotlinOptions {
jvmTarget = "1.8"
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
}
dependencies {
implementation("androidx.core:core-ktx:1.12.0")
implementation("androidx.appcompat:appcompat:1.6.1")
implementation("com.google.android.material:material:1.11.0")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.1.5")
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
}

View File

@ -0,0 +1,53 @@
# For more information about using CMake with Android Studio, read the
# documentation: https://d.android.com/studio/projects/add-native-code.html.
# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
# Sets the minimum CMake version required for this project.
cmake_minimum_required(VERSION 3.22.1)
# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
# Since this is the top level CMakeLists.txt, the project name is also accessible
# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
# build script scope).
project("llama-android")
#include(FetchContent)
#FetchContent_Declare(
# llama
# GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
# GIT_TAG master
#)
# Also provides "common"
#FetchContent_MakeAvailable(llama)
# Creates and names a library, sets it as either STATIC
# or SHARED, and provides the relative paths to its source code.
# You can define multiple libraries, and CMake builds them for you.
# Gradle automatically packages shared libraries with your APK.
#
# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
# is preferred for the same purpose.
#
#load local llama.cpp
add_subdirectory(../../../../../../ build-llama)
# In order to load a library into your app from Java/Kotlin, you must call
# System.loadLibrary() and pass the name of the library defined here;
# for GameActivity/NativeActivity derived applications, the same library name must be
# used in the AndroidManifest.xml file.
add_library(${CMAKE_PROJECT_NAME} SHARED
# List C/C++ source files with relative paths to this CMakeLists.txt.
llama-android.cpp)
# Specifies libraries CMake should link to your target library. You
# can link libraries from various origins, such as libraries defined in this
# build script, prebuilt third-party libraries, or Android system libraries.
target_link_libraries(${CMAKE_PROJECT_NAME}
# List libraries link to the target library
llama
common
android
log)

View File

@ -0,0 +1,452 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <math.h>
#include <string>
#include <unistd.h>
#include "llama.h"
#include "common.h"
// Write C++ code here.
//
// Do not forget to dynamically load the C++ library into your application.
//
// For instance,
//
// In MainActivity.java:
// static {
// System.loadLibrary("llama-android");
// }
//
// Or, in MainActivity.kt:
// companion object {
// init {
// System.loadLibrary("llama-android")
// }
// }
#define TAG "llama-android.cpp"
#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;
std::string cached_token_chars;
bool is_valid_utf8(const char * string) {
if (!string) {
return true;
}
const unsigned char * bytes = (const unsigned char *)string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
llama_model_params model_params = llama_model_default_params();
auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model);
auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) {
LOGe("load_model() failed");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
return 0;
}
return reinterpret_cast<jlong>(model);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_model_free(reinterpret_cast<llama_model *>(model));
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
auto model = reinterpret_cast<llama_model *>(jmodel);
if (!model) {
LOGe("new_context(): model cannot be null");
env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
return 0;
}
int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
LOGi("Using %d threads", n_threads);
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 2048;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
llama_context * context = llama_new_context_with_model(model, ctx_params);
if (!context) {
LOGe("llama_new_context_with_model() returned null)");
env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
"llama_new_context_with_model() returned null)");
return 0;
}
return reinterpret_cast<jlong>(context);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
llama_free(reinterpret_cast<llama_context *>(context));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
llama_backend_free();
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
llama_log_set(log_callback, NULL);
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_bench_1model(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong model_pointer,
jlong batch_pointer,
jint pp,
jint tg,
jint pl,
jint nr
) {
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto model = reinterpret_cast<llama_model *>(model_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const int n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp)");
common_batch_clear(*batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(*batch, 0, i, { 0 }, false);
}
batch->logits[batch->n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg)");
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(*batch);
for (j = 0; j < pl; j++) {
common_batch_add(*batch, 0, i, { j }, true);
}
LOGi("llama_decode() text generation: %d", i);
if (llama_decode(context, *batch) != 0) {
LOGi("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
const auto backend = "(Android)"; // TODO: What should this be?
std::stringstream result;
result << std::setprecision(2);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
llama_batch *batch = new llama_batch {
0,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
};
if (embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
} else {
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return reinterpret_cast<jlong>(batch);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
delete batch;
}
extern "C"
JNIEXPORT jlong JNICALL
Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = true;
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
return reinterpret_cast<jlong>(smpl);
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
llama_backend_init();
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jint JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1init(
JNIEnv *env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jstring jtext,
jboolean format_chat,
jint n_len
) {
cached_token_chars.clear();
const auto text = env->GetStringUTFChars(jtext, 0);
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
bool parse_special = (format_chat == JNI_TRUE);
const auto tokens_list = common_tokenize(context, text, true, parse_special);
auto n_ctx = llama_n_ctx(context);
auto n_kv_req = tokens_list.size() + n_len;
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
if (n_kv_req > n_ctx) {
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
}
for (auto id : tokens_list) {
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
}
common_batch_clear(*batch);
// evaluate the initial prompt
for (auto i = 0; i < tokens_list.size(); i++) {
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
}
// llama_decode will output logits only for the last token of the prompt
batch->logits[batch->n_tokens - 1] = true;
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");
}
env->ReleaseStringUTFChars(jtext, text);
return batch->n_tokens;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env,
jobject,
jlong context_pointer,
jlong batch_pointer,
jlong sampler_pointer,
jint n_len,
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context);
const auto vocab = llama_model_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
// sample the most likely token
const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr;
}
auto new_token_chars = common_token_to_piece(context, new_token_id);
cached_token_chars += new_token_chars;
jstring new_token = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
new_token = env->NewStringUTF(cached_token_chars.c_str());
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
cached_token_chars.clear();
} else {
new_token = env->NewStringUTF("");
}
common_batch_clear(*batch);
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() returned null");
}
return new_token;
}
extern "C"
JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
}

View File

@ -0,0 +1,180 @@
package android.llama.cpp
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread
class LLamaAndroid {
private val tag: String? = this::class.simpleName
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = "Llm-RunLoop") {
Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary("llama-android")
// Set llama log handler to Android
log_to_android()
backend_init(false)
Log.d(tag, system_info())
it.run()
}.apply {
uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
Log.e(tag, "Unhandled exception", exception)
}
}
}.asCoroutineDispatcher()
private val nlen: Int = 64
private external fun log_to_android()
private external fun load_model(filename: String): Long
private external fun free_model(model: Long)
private external fun new_context(model: Long): Long
private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean)
private external fun backend_free()
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
private external fun free_batch(batch: Long)
private external fun new_sampler(): Long
private external fun free_sampler(sampler: Long)
private external fun bench_model(
context: Long,
model: Long,
batch: Long,
pp: Int,
tg: Int,
pl: Int,
nr: Int
): String
private external fun system_info(): String
private external fun completion_init(
context: Long,
batch: Long,
text: String,
formatChat: Boolean,
nLen: Int
): Int
private external fun completion_loop(
context: Long,
batch: Long,
sampler: Long,
nLen: Int,
ncur: IntVar
): String?
private external fun kv_cache_clear(context: Long)
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
Log.d(tag, "bench(): $state")
bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
}
else -> throw IllegalStateException("No model loaded")
}
}
}
suspend fun load(pathToModel: String) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.Idle -> {
val model = load_model(pathToModel)
if (model == 0L) throw IllegalStateException("load_model() failed")
val context = new_context(model)
if (context == 0L) throw IllegalStateException("new_context() failed")
val batch = new_batch(512, 0, 1)
if (batch == 0L) throw IllegalStateException("new_batch() failed")
val sampler = new_sampler()
if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
Log.i(tag, "Loaded model $pathToModel")
threadLocalState.set(State.Loaded(model, context, batch, sampler))
}
else -> throw IllegalStateException("Model already loaded")
}
}
}
fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
if (str == null) {
break
}
emit(str)
}
kv_cache_clear(state.context)
}
else -> {}
}
}.flowOn(runLoop)
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
*/
suspend fun unload() {
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
free_context(state.context)
free_model(state.model)
free_batch(state.batch)
free_sampler(state.sampler);
threadLocalState.set(State.Idle)
}
else -> {}
}
}
}
companion object {
private class IntVar(value: Int) {
@Volatile
var value: Int = value
private set
fun inc() {
synchronized(this) {
value += 1
}
}
}
private sealed interface State {
data object Idle: State
data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
}
// Enforce only one instance of Llm.
private val _instance: LLamaAndroid = LLamaAndroid()
fun instance(): LLamaAndroid = _instance
}
}

View File

@ -8,11 +8,11 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
mavenCentral()
google()
mavenCentral()
}
}
rootProject.name = "AiChat"
rootProject.name = "LlamaAndroid"
include(":app")
include(":lib")
include(":llama")

View File

@ -55,10 +55,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the target model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * mem = llama_get_memory(ctx);

View File

@ -18,16 +18,16 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
GGML_ASSERT(model != nullptr);
// tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx, params.prompt, true, true);
inp = common_tokenize(ctx.get(), params.prompt, true, true);
fprintf(stderr, "%s: tokenization done\n", __func__);
common_ngram_cache ngram_cache;

View File

@ -28,13 +28,13 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
llama_context * ctx = llama_init->context();
llama_context_ptr & ctx = llama_init.context;
// tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx, params.prompt, true, true);
inp = common_tokenize(ctx.get(), params.prompt, true, true);
common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
@ -65,7 +65,7 @@ int main(int argc, char ** argv){
}
const int n_input = inp.size();
const int n_ctx = llama_n_ctx(ctx);
const int n_ctx = llama_n_ctx(ctx.get());
int n_drafted = 0;
int n_accept = 0;

View File

@ -29,10 +29,10 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);

View File

@ -10,13 +10,6 @@ and in some cases perplexity checked of the quantized model. And finally the
model/models need to the ggml-org on Hugging Face. This tool/example tries to
help with this process.
> 📝 **Note:** When adding a new model from an existing family, verify the
> previous version passes logits verification first. Existing models can have
> subtle numerical differences that don't affect generation quality but cause
> logits mismatches. Identifying these upfront whether they exist in llama.cpp,
> the conversion script, or in an upstream implementation, can save significant
> debugging time.
### Overview
The idea is that the makefile targets and scripts here can be used in the
development/conversion process assisting with things like:

View File

@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:
```sh
llama-server -hf {namespace}/{model_name}-GGUF -c 0
llama-server -hf {namespace}/{model_name}-GGUF -c 0 -fa
```
Then, access http://localhost:8080

View File

@ -5,7 +5,7 @@ import os
import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np
@ -116,11 +116,11 @@ def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
@ -130,7 +130,6 @@ unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
args = parser.parse_args()
model_path = os.environ.get("MODEL_PATH", args.model_path)
@ -143,13 +142,8 @@ if model_path is None:
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
multimodal = False
full_config = config
print("Model type: ", config.model_type)
if "vocab_size" not in config and "text_config" in config:
config = config.text_config
multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
@ -175,14 +169,9 @@ if unreleased_model_name:
print(f"Failed to import or load model: {e}")
exit(1)
else:
if multimodal:
model = AutoModelForImageTextToText.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
)
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
)
for name, module in model.named_modules():
if len(list(module.children())) == 0: # only leaf modules
@ -196,10 +185,7 @@ model_name = os.path.basename(model_path)
print(f"Model class: {model.__class__.__name__}")
device = next(model.parameters()).device
if args.prompt_file:
with open(args.prompt_file, encoding='utf-8') as f:
prompt = f.read()
elif os.getenv("MODEL_TESTING_PROMPT"):
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
@ -209,21 +195,12 @@ print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
batch_size = 512
with torch.no_grad():
past = None
outputs = None
for i in range(0, input_ids.size(1), batch_size):
print(f"Processing chunk with tokens {i} to {i + batch_size}")
chunk = input_ids[:, i:i + batch_size]
outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
past = outputs.past_key_values
logits = outputs.logits # type: ignore
outputs = model(input_ids.to(model.device))
logits = outputs.logits
# Extract logits for the last token (next token prediction)
last_logits = logits[0, -1, :].float().cpu().numpy()
last_logits = logits[0, -1, :].cpu().numpy()
print(f"Logits shape: {logits.shape}")
print(f"Last token logits shape: {last_logits.shape}")

View File

@ -34,11 +34,8 @@ done
MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}"
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
CONVERTED_MODEL_PATH="${CONVERTED_EMBEDDING_PATH:-"$CONVERTED_EMBEDDING_MODEL"}"
CONVERTED_MODEL_NAME="${CONVERTED_MODEL_NAME:-$(basename "$CONVERTED_MODEL_PATH" .gguf)}"
if [ -t 0 ]; then
CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
else
# Process piped JSON data and convert to binary (matching logits.cpp format)
TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.binn)

View File

@ -192,10 +192,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the target model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * mem = llama_get_memory(ctx);

View File

@ -149,10 +149,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -34,10 +34,10 @@ int main(int argc, char ** argv) {
std::string result2;
// init
auto llama_init = common_init_from_params(params);
common_init_result llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);

View File

@ -40,10 +40,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
auto llama_init_tgt = common_init_from_params(params);
common_init_result llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
@ -61,10 +61,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto llama_init_dft = common_init_from_params(params);
common_init_result llama_init_dft = common_init_from_params(params);
//model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
//model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
@ -255,8 +255,6 @@ int main(int argc, char ** argv) {
LOG_INF("target:\n\n");
common_perf_print(ctx_tgt, smpl);
llama_batch_free(batch_tgt);
common_sampler_free(smpl);
common_speculative_free(spec);

View File

@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
auto llama_init_tgt = common_init_from_params(params);
common_init_result llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
// load the draft model
params.devices = params.speculative.devices;
@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
auto llama_init_dft = common_init_from_params(params);
common_init_result llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

View File

@ -39,10 +39,9 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
@ -55,8 +54,8 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
struct lr_opt & lr = params.lr;
LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
@ -71,7 +70,7 @@ int main(int argc, char ** argv) {
/*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer,
};
llama_opt_init(ctx, model, lopt_params);
llama_opt_init(ctx.get(), model.get(), lopt_params);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
@ -79,7 +78,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_t result_eval = ggml_opt_result_init();
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");
@ -89,7 +88,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);
llama_model_save_to_file(model, params.out_file.c_str());
llama_model_save_to_file(model.get(), params.out_file.c_str());
llama_backend_free();

View File

@ -54,10 +54,6 @@ if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
# TODO
else()
set(GGML_STANDALONE OFF)
if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
endif()
endif()
if (EMSCRIPTEN)

View File

@ -53,14 +53,7 @@ GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
// call with a worst-case graph to avoid buffer reallocations
// not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
// returns false if the buffer allocation failed
// ggml_gallocr_resrve_n_size writes the buffer sizes per galloc buffer that would be allocated by ggml_gallocr_reserve_n to sizes
GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
GGML_API void ggml_gallocr_reserve_n_size(
ggml_gallocr_t galloc,
struct ggml_cgraph * graph,
const int * node_buffer_ids,
const int * leaf_buffer_ids,
size_t * sizes);
GGML_API bool ggml_gallocr_reserve_n(
ggml_gallocr_t galloc,
struct ggml_cgraph * graph,
@ -75,8 +68,6 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i
// Utils
// Create a buffer and allocate all the tensors in a ggml_context
// ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft
GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);

Some files were not shown because too many files have changed in this diff Show More