diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile
index 83182c9700..db221b0b81 100644
--- a/.devops/cann.Dockerfile
+++ b/.devops/cann.Dockerfile
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
# ENTRYPOINT ["/app/llama-server"]
### Target: light
-# Lightweight image containing only llama-cli
+# Lightweight image containing only llama-cli and llama-completion
# ==============================================================================
FROM base AS light
diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile
index ef43d78cd2..6581187f32 100644
--- a/.devops/llama-cli-cann.Dockerfile
+++ b/.devops/llama-cli-cann.Dockerfile
@@ -23,11 +23,12 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH
RUN echo "Building with static libs" && \
source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_TESTS=OFF && \
- cmake --build build --config Release --target llama-cli
+ cmake --build build --config Release --target llama-cli && \
+ cmake --build build --config Release --target llama-completion
# TODO: use image with NNRT
FROM ascendai/cann:$ASCEND_VERSION AS runtime
-COPY --from=build /app/build/bin/llama-cli /llama-cli
+COPY --from=build /app/build/bin/llama-cli /app/build/bin/llama-completion /
ENV LC_ALL=C.utf8
diff --git a/.devops/llama-cpp-cuda.srpm.spec b/.devops/llama-cpp-cuda.srpm.spec
index 3bbf4a4def..4d42a906b1 100644
--- a/.devops/llama-cpp-cuda.srpm.spec
+++ b/.devops/llama-cpp-cuda.srpm.spec
@@ -37,6 +37,7 @@ make -j GGML_CUDA=1
%install
mkdir -p %{buildroot}%{_bindir}/
cp -p llama-cli %{buildroot}%{_bindir}/llama-cuda-cli
+cp -p llama-completion %{buildroot}%{_bindir}/llama-cuda-completion
cp -p llama-server %{buildroot}%{_bindir}/llama-cuda-server
cp -p llama-simple %{buildroot}%{_bindir}/llama-cuda-simple
@@ -68,6 +69,7 @@ rm -rf %{_builddir}/*
%files
%{_bindir}/llama-cuda-cli
+%{_bindir}/llama-cuda-completion
%{_bindir}/llama-cuda-server
%{_bindir}/llama-cuda-simple
/usr/lib/systemd/system/llamacuda.service
diff --git a/.devops/llama-cpp.srpm.spec b/.devops/llama-cpp.srpm.spec
index 45902dcf89..0a4f43058d 100644
--- a/.devops/llama-cpp.srpm.spec
+++ b/.devops/llama-cpp.srpm.spec
@@ -39,6 +39,7 @@ make -j
%install
mkdir -p %{buildroot}%{_bindir}/
cp -p llama-cli %{buildroot}%{_bindir}/llama-cli
+cp -p llama-completion %{buildroot}%{_bindir}/llama-completion
cp -p llama-server %{buildroot}%{_bindir}/llama-server
cp -p llama-simple %{buildroot}%{_bindir}/llama-simple
@@ -70,6 +71,7 @@ rm -rf %{_builddir}/*
%files
%{_bindir}/llama-cli
+%{_bindir}/llama-completion
%{_bindir}/llama-server
%{_bindir}/llama-simple
/usr/lib/systemd/system/llama.service
diff --git a/.github/ISSUE_TEMPLATE/019-bug-misc.yml b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
index 1904e31fdc..e1bd08ddd2 100644
--- a/.github/ISSUE_TEMPLATE/019-bug-misc.yml
+++ b/.github/ISSUE_TEMPLATE/019-bug-misc.yml
@@ -86,6 +86,7 @@ body:
description: >
If applicable, please copy and paste any relevant log output, including any generated text.
This will be automatically formatted into code, so no need for backticks.
+ If you are encountering problems specifically with the `llama_params_fit` module, always upload `--verbose` logs as well.
render: shell
validations:
required: false
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index af4c60be64..de3ad06065 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -70,6 +70,7 @@ jobs:
with:
key: macOS-latest-cmake-arm64
evict-old-files: 1d
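+ # only save the ccache cache on pushes to master; all other runs only restore it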
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -106,6 +107,7 @@ jobs:
with:
key: macOS-latest-cmake-x64
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -142,6 +144,7 @@ jobs:
with:
key: macOS-latest-cmake-arm64-webgpu
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dawn Dependency
id: dawn-depends
@@ -195,6 +198,7 @@ jobs:
with:
key: ubuntu-cpu-cmake-${{ matrix.build }}
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build Dependencies
id: build_depends
@@ -276,6 +280,7 @@ jobs:
with:
key: ubuntu-latest-cmake-sanitizer-${{ matrix.sanitizer }}
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -396,6 +401,7 @@ jobs:
with:
key: ubuntu-24-cmake-vulkan-deb
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -431,6 +437,7 @@ jobs:
with:
key: ubuntu-24-cmake-vulkan
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -490,6 +497,7 @@ jobs:
with:
key: ubuntu-24-cmake-webgpu
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -562,6 +570,7 @@ jobs:
with:
key: ubuntu-latest-wasm-webgpu
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Emscripten
run: |
@@ -609,6 +618,7 @@ jobs:
with:
key: ubuntu-22-cmake-hip
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake HIP support
id: cmake_build
@@ -641,6 +651,7 @@ jobs:
with:
key: ubuntu-22-cmake-musa
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with native CMake MUSA support
id: cmake_build
@@ -688,6 +699,7 @@ jobs:
with:
key: ubuntu-22-cmake-sycl
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -738,6 +750,7 @@ jobs:
with:
key: ubuntu-22-cmake-sycl-fp16
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -771,6 +784,7 @@ jobs:
with:
key: macOS-latest-cmake-ios
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -802,6 +816,7 @@ jobs:
with:
key: macOS-latest-cmake-tvos
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -863,6 +878,7 @@ jobs:
with:
key: macOS-latest-swift
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download xcframework artifact
uses: actions/download-artifact@v4
@@ -905,6 +921,7 @@ jobs:
key: windows-msys2
variant: ccache
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Setup ${{ matrix.sys }}
uses: msys2/setup-msys2@v2
@@ -973,6 +990,7 @@ jobs:
key: windows-latest-cmake-${{ matrix.build }}
variant: ccache
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Download OpenBLAS
id: get_openblas
@@ -1077,6 +1095,7 @@ jobs:
with:
key: ubuntu-latest-cmake-cuda
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build with CMake
run: |
@@ -1109,6 +1128,7 @@ jobs:
key: windows-cuda-${{ matrix.cuda }}
variant: ccache
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install Cuda Toolkit
uses: ./.github/actions/windows-setup-cuda
@@ -1160,6 +1180,7 @@ jobs:
key: windows-latest-cmake-sycl
variant: ccache
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Install
run: |
@@ -1221,6 +1242,7 @@ jobs:
with:
key: ${{ github.job }}
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Build
id: cmake_build
@@ -1466,6 +1488,7 @@ jobs:
with:
key: ggml-ci-x64-cpu-low-perf
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -1491,6 +1514,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-low-perf
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -1516,6 +1540,7 @@ jobs:
with:
key: ggml-ci-x64-cpu-high-perf
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -1541,6 +1566,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-high-perf
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -1566,6 +1592,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-high-perf-sve
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -1701,6 +1728,7 @@ jobs:
with:
key: ggml-ci-arm64-cpu-kleidiai
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Dependencies
id: depends
@@ -2084,6 +2112,7 @@ jobs:
with:
key: ggml-ci-arm64-graviton4-kleidiai
evict-old-files: 1d
+ save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
- name: Test
id: ggml-ci
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 446cae9f84..4cc2f4665c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -66,16 +66,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
- zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- - name: Upload artifacts (zip)
- uses: actions/upload-artifact@v4
- with:
- path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip
- name: llama-bin-macos-arm64.zip
-
- - name: Upload artifacts (tar)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
@@ -127,16 +120,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
- zip -y -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- - name: Upload artifacts (zip)
- uses: actions/upload-artifact@v4
- with:
- path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip
- name: llama-bin-macos-x64.zip
-
- - name: Upload artifacts (tar)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
@@ -196,16 +182,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
- zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- - name: Upload artifacts (zip)
- uses: actions/upload-artifact@v4
- with:
- path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip
- name: llama-bin-ubuntu-${{ matrix.build }}.zip
-
- - name: Upload artifacts (tar)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
@@ -256,16 +235,9 @@ jobs:
id: pack_artifacts
run: |
cp LICENSE ./build/bin/
- zip -y -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/*
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- - name: Upload artifacts (zip)
- uses: actions/upload-artifact@v4
- with:
- path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip
- name: llama-bin-ubuntu-vulkan-x64.zip
-
- - name: Upload artifacts (tar)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
@@ -716,21 +688,16 @@ jobs:
- name: Pack artifacts
id: pack_artifacts
run: |
- zip -y -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
- tar -czvf llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz -C build-apple llama.xcframework
+ # Zip file is required for Swift Package Manager, which does not support tar.gz for binary targets.
+ # For more details, see https://developer.apple.com/documentation/xcode/distributing-binary-frameworks-as-swift-packages
+ zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
- - name: Upload artifacts (zip)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
- - name: Upload artifacts (tar)
- uses: actions/upload-artifact@v4
- with:
- path: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
- name: llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz
-
openEuler-cann:
strategy:
@@ -797,7 +764,7 @@ jobs:
cp LICENSE ./build/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
- - name: Upload artifacts (tar)
+ - name: Upload artifacts
uses: actions/upload-artifact@v4
with:
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}.tar.gz
@@ -889,9 +856,6 @@ jobs:
with:
tag_name: ${{ steps.tag.outputs.name }}
body: |
- > [!WARNING]
- > **Release Format Update**: Linux releases will soon use .tar.gz archives instead of .zip. Please make the necessary changes to your deployment scripts.
-
${{ github.event.head_commit.message }}
@@ -901,7 +865,7 @@ jobs:
**macOS/iOS:**
- [macOS Apple Silicon (arm64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz)
- [macOS Intel (x64)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz)
- - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.tar.gz)
+ - [iOS XCFramework](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-xcframework.zip)
**Linux:**
- [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz)
@@ -911,8 +875,8 @@ jobs:
**Windows:**
- [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip)
- [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip)
- - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip)
- - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip)
+ - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip)
+ - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.1-x64.zip) - [CUDA 13.1 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.1-x64.zip)
- [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip)
- [Windows x64 (SYCL)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip)
- [Windows x64 (HIP)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-hip-radeon-x64.zip)
diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml
new file mode 100644
index 0000000000..544c4ad408
--- /dev/null
+++ b/.github/workflows/server-webui.yml
@@ -0,0 +1,225 @@
+# Server WebUI build and tests
+name: Server WebUI
+
+on:
+ workflow_dispatch: # allows manual triggering
+ inputs:
+ sha:
+ description: 'Commit SHA1 to build'
+ required: false
+ type: string
+ slow_tests:
+ description: 'Run slow tests'
+ required: true
+ type: boolean
+ push:
+ branches:
+ - master
+ paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
+ pull_request:
+ types: [opened, synchronize, reopened]
+ paths: ['.github/workflows/server-webui.yml', 'tools/server/webui/**.*', 'tools/server/tests/**.*', 'tools/server/public/**']
+
+env:
+ LLAMA_LOG_COLORS: 1
+ LLAMA_LOG_PREFIX: 1
+ LLAMA_LOG_TIMESTAMPS: 1
+ LLAMA_LOG_VERBOSITY: 10
+
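+# PR runs share a concurrency group (head_ref), so a new push cancels the in-progress run; master pushes get a unique group via run_id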
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ webui-check:
+ name: WebUI Checks
+ runs-on: ubuntu-latest
+ continue-on-error: true
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
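+ # prefer an explicitly requested sha input, then the PR head commit, then the pushed ref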
+ ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
+
+ - name: Setup Node.js
+ id: node
+ uses: actions/setup-node@v4
+ with:
+ node-version: "22"
+ cache: "npm"
+ cache-dependency-path: "tools/server/webui/package-lock.json"
+
+ - name: Install dependencies
+ id: setup
+ if: ${{ steps.node.conclusion == 'success' }}
+ run: npm ci
+ working-directory: tools/server/webui
+
+ - name: Run type checking
+ if: ${{ always() && steps.setup.conclusion == 'success' }}
+ run: npm run check
+ working-directory: tools/server/webui
+
+ - name: Run linting
+ if: ${{ always() && steps.setup.conclusion == 'success' }}
+ run: npm run lint
+ working-directory: tools/server/webui
+
+ - name: Build application
+ if: ${{ always() && steps.setup.conclusion == 'success' }}
+ run: npm run build
+ working-directory: tools/server/webui
+
+ - name: Install Playwright browsers
+ id: playwright
+ if: ${{ always() && steps.setup.conclusion == 'success' }}
+ run: npx playwright install --with-deps
+ working-directory: tools/server/webui
+
+ - name: Build Storybook
+ if: ${{ always() && steps.playwright.conclusion == 'success' }}
+ run: npm run build-storybook
+ working-directory: tools/server/webui
+
+ - name: Run Client tests
+ if: ${{ always() && steps.playwright.conclusion == 'success' }}
+ run: npm run test:client
+ working-directory: tools/server/webui
+
+ - name: Run Unit tests
+ if: ${{ always() && steps.playwright.conclusion == 'success' }}
+ run: npm run test:unit
+ working-directory: tools/server/webui
+
+ - name: Run UI tests
+ if: ${{ always() && steps.playwright.conclusion == 'success' }}
+ run: npm run test:ui -- --testTimeout=60000
+ working-directory: tools/server/webui
+
+ - name: Run E2E tests
+ if: ${{ always() && steps.playwright.conclusion == 'success' }}
+ run: npm run test:e2e
+ working-directory: tools/server/webui
+
+ server-build:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
+ build_type: [RelWithDebInfo]
+ include:
+ - build_type: Release
+ sanitizer: ""
+ fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
+
+ steps:
+ - name: Dependencies
+ id: depends
+ run: |
+ sudo apt-get update
+ sudo apt-get -y install \
+ build-essential \
+ xxd \
+ git \
+ cmake \
+ curl \
+ wget \
+ language-pack-en \
+ libssl-dev
+
+ - name: Clone
+ id: checkout
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
+
+ - name: Python setup
+ id: setup_python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.11'
+
+ - name: Tests dependencies
+ id: test_dependencies
+ run: |
+ pip install -r tools/server/tests/requirements.txt
+
+ - name: Setup Node.js for WebUI
+ uses: actions/setup-node@v4
+ with:
+ node-version: "22"
+ cache: "npm"
+ cache-dependency-path: "tools/server/webui/package-lock.json"
+
+ - name: Install WebUI dependencies
+ run: npm ci
+ working-directory: tools/server/webui
+
+ - name: Build WebUI
+ run: npm run build
+ working-directory: tools/server/webui
+
+ - name: Build (no OpenMP)
+ id: cmake_build_no_openmp
+ if: ${{ matrix.sanitizer == 'THREAD' }}
+ run: |
+ cmake -B build \
+ -DGGML_NATIVE=OFF \
+ -DLLAMA_CURL=OFF \
+ -DLLAMA_OPENSSL=ON \
+ -DLLAMA_BUILD_SERVER=ON \
+ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
+ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
+ -DGGML_OPENMP=OFF ;
+ cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
+
+ - name: Build (sanitizers)
+ id: cmake_build_sanitizers
+ if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
+ run: |
+ cmake -B build \
+ -DGGML_NATIVE=OFF \
+ -DLLAMA_CURL=OFF \
+ -DLLAMA_OPENSSL=ON \
+ -DLLAMA_BUILD_SERVER=ON \
+ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
+ -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
+ cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
+
+ - name: Build
+ id: cmake_build
+ if: ${{ matrix.sanitizer == '' }}
+ run: |
+ cmake -B build \
+ -DGGML_NATIVE=OFF \
+ -DLLAMA_CURL=OFF \
+ -DLLAMA_OPENSSL=ON \
+ -DLLAMA_BUILD_SERVER=ON \
+ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
+ cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
+
+ - name: Tests
+ id: server_integration_tests
+ if: ${{ matrix.sanitizer == '' }}
+ env:
+ GITHUB_ACTIONS: "true"
+ run: |
+ cd tools/server/tests
+ ./tests.sh
+
+ - name: Tests (sanitizers)
+ id: server_integration_tests_sanitizers
+ if: ${{ matrix.sanitizer != '' }}
+ run: |
+ cd tools/server/tests
+ LLAMA_SANITIZE=1 ./tests.sh
+
+ - name: Slow tests
+ id: server_integration_tests_slow
+ if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
+ run: |
+ cd tools/server/tests
+ SLOW_TESTS=1 ./tests.sh
diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml
index a57d0e8b1c..f9e2a79af7 100644
--- a/.github/workflows/server.yml
+++ b/.github/workflows/server.yml
@@ -76,270 +76,6 @@ jobs:
run: |
pip install -r tools/server/tests/requirements.txt
- webui-setup:
- name: WebUI Setup
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
- - name: Setup Node.js
- uses: actions/setup-node@v4
- with:
- node-version: "22"
- cache: "npm"
- cache-dependency-path: "tools/server/webui/package-lock.json"
-
- - name: Cache node_modules
- uses: actions/cache@v4
- id: cache-node-modules
- with:
- path: tools/server/webui/node_modules
- key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
- restore-keys: |
- ${{ runner.os }}-node-modules-
-
- - name: Install dependencies
- if: steps.cache-node-modules.outputs.cache-hit != 'true'
- run: npm ci
- working-directory: tools/server/webui
-
- webui-check:
- needs: webui-setup
- name: WebUI Check
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
- - name: Setup Node.js
- uses: actions/setup-node@v4
- with:
- node-version: "22"
-
- - name: Restore node_modules cache
- uses: actions/cache@v4
- with:
- path: tools/server/webui/node_modules
- key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
- restore-keys: |
- ${{ runner.os }}-node-modules-
-
- - name: Run type checking
- run: npm run check
- working-directory: tools/server/webui
-
- - name: Run linting
- run: npm run lint
- working-directory: tools/server/webui
-
- webui-build:
- needs: webui-check
- name: WebUI Build
- runs-on: ubuntu-latest
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
- - name: Setup Node.js
- uses: actions/setup-node@v4
- with:
- node-version: "22"
-
- - name: Restore node_modules cache
- uses: actions/cache@v4
- with:
- path: tools/server/webui/node_modules
- key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
- restore-keys: |
- ${{ runner.os }}-node-modules-
-
- - name: Build application
- run: npm run build
- working-directory: tools/server/webui
-
- webui-tests:
- needs: webui-build
- name: Run WebUI tests
- permissions:
- contents: read
-
- runs-on: ubuntu-latest
-
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
-
- - name: Setup Node.js
- uses: actions/setup-node@v4
- with:
- node-version: "22"
-
- - name: Restore node_modules cache
- uses: actions/cache@v4
- with:
- path: tools/server/webui/node_modules
- key: ${{ runner.os }}-node-modules-${{ hashFiles('tools/server/webui/package-lock.json') }}
- restore-keys: |
- ${{ runner.os }}-node-modules-
-
- - name: Install Playwright browsers
- run: npx playwright install --with-deps
- working-directory: tools/server/webui
-
- - name: Build Storybook
- run: npm run build-storybook
- working-directory: tools/server/webui
-
- - name: Run Client tests
- run: npm run test:client
- working-directory: tools/server/webui
-
- - name: Run Server tests
- run: npm run test:server
- working-directory: tools/server/webui
-
- - name: Run UI tests
- run: npm run test:ui -- --testTimeout=60000
- working-directory: tools/server/webui
-
- - name: Run E2E tests
- run: npm run test:e2e
- working-directory: tools/server/webui
-
- server-build:
- needs: [webui-tests]
- runs-on: ubuntu-latest
-
- strategy:
- matrix:
- sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
- build_type: [RelWithDebInfo]
- include:
- - build_type: Release
- sanitizer: ""
- fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
-
- steps:
- - name: Dependencies
- id: depends
- run: |
- sudo apt-get update
- sudo apt-get -y install \
- build-essential \
- xxd \
- git \
- cmake \
- curl \
- wget \
- language-pack-en \
- libssl-dev
-
- - name: Clone
- id: checkout
- uses: actions/checkout@v4
- with:
- fetch-depth: 0
- ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
-
- - name: Python setup
- id: setup_python
- uses: actions/setup-python@v5
- with:
- python-version: '3.11'
-
- - name: Tests dependencies
- id: test_dependencies
- run: |
- pip install -r tools/server/tests/requirements.txt
-
- - name: Setup Node.js for WebUI
- uses: actions/setup-node@v4
- with:
- node-version: "22"
- cache: "npm"
- cache-dependency-path: "tools/server/webui/package-lock.json"
-
- - name: Install WebUI dependencies
- run: npm ci
- working-directory: tools/server/webui
-
- - name: Build WebUI
- run: npm run build
- working-directory: tools/server/webui
-
- - name: Build (no OpenMP)
- id: cmake_build_no_openmp
- if: ${{ matrix.sanitizer == 'THREAD' }}
- run: |
- cmake -B build \
- -DGGML_NATIVE=OFF \
- -DLLAMA_CURL=OFF \
- -DLLAMA_OPENSSL=ON \
- -DLLAMA_BUILD_SERVER=ON \
- -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
- -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
- -DGGML_OPENMP=OFF ;
- cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
-
- - name: Build (sanitizers)
- id: cmake_build_sanitizers
- if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
- run: |
- cmake -B build \
- -DGGML_NATIVE=OFF \
- -DLLAMA_CURL=OFF \
- -DLLAMA_OPENSSL=ON \
- -DLLAMA_BUILD_SERVER=ON \
- -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
- -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
- cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
-
- - name: Build (sanitizers)
- id: cmake_build
- if: ${{ matrix.sanitizer == '' }}
- run: |
- cmake -B build \
- -DGGML_NATIVE=OFF \
- -DLLAMA_CURL=OFF \
- -DLLAMA_OPENSSL=ON \
- -DLLAMA_BUILD_SERVER=ON \
- -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
- cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
-
- - name: Tests
- id: server_integration_tests
- if: ${{ matrix.sanitizer == '' }}
- env:
- GITHUB_ACTIONS: "true"
- run: |
- cd tools/server/tests
- ./tests.sh
-
- - name: Tests (sanitizers)
- id: server_integration_tests_sanitizers
- if: ${{ matrix.sanitizer != '' }}
- run: |
- cd tools/server/tests
- LLAMA_SANITIZE=1 ./tests.sh
-
- - name: Slow tests
- id: server_integration_tests_slow
- if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
- run: |
- cd tools/server/tests
- SLOW_TESTS=1 ./tests.sh
-
-
server-windows:
runs-on: windows-2022
diff --git a/.github/copilot-instructions.md b/AGENTS.md
similarity index 88%
rename from .github/copilot-instructions.md
rename to AGENTS.md
index ad13c6ea8d..e0a65c0ab7 100644
--- a/.github/copilot-instructions.md
+++ b/AGENTS.md
@@ -1,4 +1,4 @@
-# Copilot Instructions for llama.cpp
+# Instructions for llama.cpp
## Repository Overview
@@ -12,6 +12,24 @@ llama.cpp is a large-scale C/C++ project for efficient LLM (Large Language Model
- **Backends supported**: CPU (AVX/NEON/RVV optimized), CUDA, Metal, Vulkan, SYCL, ROCm, MUSA
- **License**: MIT
+## Disclose AI Usage
+
+Remind contributors that the project requires disclosure of any AI usage in pull requests. This requirement exists because AI-generated code can contain suboptimal optimizations and hidden bugs while being presented with unwarranted confidence.
+
+When generating significant portions of code, address this by:
+- Informing the user that AI-generated content may be rejected by maintainers.
+- Clearly marking AI-generated code in commit messages and comments.
+ - Example of commit message: `[AI] Fix a race condition in ...`
+ - Example of code comment: `// [AI] spawn a new thread ...`
+
+These measures apply to:
+- Changes resulting in large portions of code or complex logic.
+- Modifications or additions to public APIs in `llama.h`, `ggml.h`, or `mtmd.h`.
+- Backend-related changes, such as those involving CPU, CUDA, Metal, Vulkan, etc.
+- Modifications to `tools/server`.
+
+Note: These measures can be omitted for small fixes or trivial changes.
+
## Build Instructions
### Prerequisites
@@ -251,6 +269,7 @@ Primary tools:
- **Cross-platform compatibility**: Test on Linux, macOS, Windows when possible
- **Performance focus**: This is a performance-critical inference library
- **API stability**: Changes to `include/llama.h` require careful consideration
+- **Disclose AI Usage**: Refer to the "Disclose AI Usage" section earlier in this document
### Git Workflow
- Always create feature branches from `master`
diff --git a/CODEOWNERS b/CODEOWNERS
index 8a0c98c968..750096d9a1 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -32,7 +32,7 @@
/examples/export-docs/ @ggerganov
/examples/gen-docs/ @ggerganov
/examples/gguf/ @ggerganov
-/examples/llama.android/ @ggerganov
+/examples/llama.android/ @ggerganov @hanyin-arm @naco-siren
/examples/llama.swiftui/ @ggerganov
/examples/llama.vim @ggerganov
/examples/lookahead/ @ggerganov
diff --git a/README.md b/README.md
index 5f2076d0a3..ed956bb02e 100644
--- a/README.md
+++ b/README.md
@@ -190,6 +190,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
+- Android: [llama.android](/examples/llama.android)
diff --git a/SECURITY.md b/SECURITY.md
index 9c86ae91b5..ae496f4e3d 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -68,3 +68,6 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report.
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
+
+> [!IMPORTANT]
+> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 0182767c2b..f7b99159e3 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -85,6 +85,9 @@ add_library(${TARGET} STATIC
unicode.h
)
+target_include_directories(${TARGET} PUBLIC . ../vendor)
+target_compile_features (${TARGET} PUBLIC cxx_std_17)
+
if (BUILD_SHARED_LIBS)
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
@@ -151,9 +154,7 @@ if (LLAMA_LLGUIDANCE)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS})
endif ()
-target_include_directories(${TARGET} PUBLIC . ../vendor)
-target_compile_features (${TARGET} PUBLIC cxx_std_17)
-target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
+target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
#
diff --git a/common/arg.cpp b/common/arg.cpp
index acf4c8f8a8..1302065498 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -96,6 +96,11 @@ common_arg & common_arg::set_sparam() {
return *this;
}
+common_arg & common_arg::set_preset_only() {
+ is_preset_only = true;
+ return *this;
+}
+
bool common_arg::in_example(enum llama_example ex) {
return examples.find(ex) != examples.end();
}
@@ -420,6 +425,8 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}
};
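+ // remember which args have already been seen so repeated flags can trigger the deprecation warning below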
+ std::set<std::string> seen_args;
+
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@@ -430,6 +437,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
if (arg_to_options.find(arg) == arg_to_options.end()) {
throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str()));
}
+ if (!seen_args.insert(arg).second) {
+ LOG_WRN("DEPRECATED: argument '%s' specified multiple times, use comma-separated values instead (only last value will be used)\n", arg.c_str());
+ }
auto & tmp = arg_to_options[arg];
auto opt = *tmp.first;
bool is_positive = tmp.second;
@@ -750,6 +760,8 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<std::string, std::string> & out_map) {
+ std::set<std::string> seen_args;
+
for (int i = 1; i < argc; i++) {
const std::string arg_prefix = "--";
@@ -760,8 +772,16 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<std::string, std::string> & out_map) {
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ std::ifstream file(item);
+ if (!file) {
+ throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
+ }
+ params.in_files.push_back(item);
}
- params.in_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_IMATRIX}));
add_opt(common_arg(
@@ -1389,7 +1425,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
- {"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
+ {"--sampler-seq", "--sampling-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) {
params.sampling.samplers = common_sampler_types_from_chars(value);
@@ -1888,13 +1924,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
LOG_WRN("DEPRECATED: --defrag-thold is deprecated and no longer necessary to specify\n");
}
).set_env("LLAMA_ARG_DEFRAG_THOLD"));
- add_opt(common_arg(
- {"-np", "--parallel"}, "N",
- string_format("number of parallel sequences to decode (default: %d)", params.n_parallel),
- [](common_params & params, int value) {
- params.n_parallel = value;
- }
- ).set_env("LLAMA_ARG_N_PARALLEL"));
+ if (ex == LLAMA_EXAMPLE_SERVER) {
+ // this is to make sure this option appears in the server-specific section of the help message
+ add_opt(common_arg(
+ {"-np", "--parallel"}, "N",
+ string_format("number of server slots (default: %d, -1 = auto)", params.n_parallel),
+ [](common_params & params, int value) {
+ if (value == 0) {
+ throw std::invalid_argument("error: invalid value for n_parallel\n");
+ }
+ params.n_parallel = value;
+ }
+ ).set_env("LLAMA_ARG_N_PARALLEL").set_examples({LLAMA_EXAMPLE_SERVER}));
+ } else {
+ add_opt(common_arg(
+ {"-np", "--parallel"}, "N",
+ string_format("number of parallel sequences to decode (default: %d)", params.n_parallel),
+ [](common_params & params, int value) {
+ params.n_parallel = value;
+ }
+ ).set_env("LLAMA_ARG_N_PARALLEL"));
+ }
add_opt(common_arg(
{"-ns", "--sequences"}, "N",
string_format("number of sequences to decode (default: %d)", params.n_sequences),
@@ -1943,9 +1993,11 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_OFFLOAD"));
add_opt(common_arg(
{"--image", "--audio"}, "FILE",
- "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n",
+ "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
[](common_params & params, const std::string & value) {
- params.image.emplace_back(value);
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ params.image.emplace_back(item);
+ }
}
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
@@ -2031,26 +2083,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
));
add_opt(common_arg(
- {"--override-tensor", "-ot"}, "=,...",
+ {"-ot", "--override-tensor"}, "=,...",
"override tensor buffer type", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.tensor_buft_overrides);
}
));
add_opt(common_arg(
- {"--override-tensor-draft", "-otd"}, "=,...",
+ {"-otd", "--override-tensor-draft"}, "=,...",
"override tensor buffer type for draft model", [](common_params & params, const std::string & value) {
parse_tensor_buffer_overrides(value, params.speculative.tensor_buft_overrides);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
- {"--cpu-moe", "-cmoe"},
+ {"-cmoe", "--cpu-moe"},
"keep all Mixture of Experts (MoE) weights in the CPU",
[](common_params & params) {
params.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
}
).set_env("LLAMA_ARG_CPU_MOE"));
add_opt(common_arg(
- {"--n-cpu-moe", "-ncmoe"}, "N",
+ {"-ncmoe", "--n-cpu-moe"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU",
[](common_params & params, int value) {
if (value < 0) {
@@ -2065,14 +2117,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_N_CPU_MOE"));
add_opt(common_arg(
- {"--cpu-moe-draft", "-cmoed"},
+ {"-cmoed", "--cpu-moe-draft"},
"keep all Mixture of Experts (MoE) weights in the CPU for the draft model",
[](common_params & params) {
params.speculative.tensor_buft_overrides.push_back(llm_ffn_exps_cpu_override());
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_CPU_MOE_DRAFT"));
add_opt(common_arg(
- {"--n-cpu-moe-draft", "-ncmoed"}, "N",
+ {"-ncmoed", "--n-cpu-moe-draft"}, "N",
"keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model",
[](common_params & params, int value) {
if (value < 0) {
@@ -2192,12 +2244,39 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
));
add_opt(common_arg(
- {"--override-kv"}, "KEY=TYPE:VALUE",
- "advanced option to override model metadata by key. may be specified multiple times.\n"
- "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false",
+ {"--override-kv"}, "KEY=TYPE:VALUE,...",
+ "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
+ "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
[](common_params & params, const std::string & value) {
- if (!string_parse_kv_override(value.c_str(), params.kv_overrides)) {
- throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", value.c_str()));
+ std::vector<std::string> kv_overrides;
+
+ std::string current;
+ bool escaping = false;
+
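+ // split on commas; a backslash escapes the next character, so values can contain literal commas (e.g. str:a\,b)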
+ for (const char c : value) {
+ if (escaping) {
+ current.push_back(c);
+ escaping = false;
+ } else if (c == '\\') {
+ escaping = true;
+ } else if (c == ',') {
+ kv_overrides.push_back(current);
+ current.clear();
+ } else {
+ current.push_back(c);
+ }
+ }
+
+ if (escaping) {
+ current.push_back('\\');
+ }
+
+ kv_overrides.push_back(current);
+
+ for (const auto & kv_override : kv_overrides) {
+ if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
+ throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
+ }
}
}
));
@@ -2211,33 +2290,50 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--lora"}, "FNAME",
- "path to LoRA adapter (can be repeated to use multiple adapters)",
+ "path to LoRA adapter (use comma-separated values to load multiple adapters)",
[](common_params & params, const std::string & value) {
- params.lora_adapters.push_back({ std::string(value), 1.0, "", "", nullptr });
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
+ }
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
- {"--lora-scaled"}, "FNAME", "SCALE",
- "path to LoRA adapter with user defined scaling (can be repeated to use multiple adapters)",
- [](common_params & params, const std::string & fname, const std::string & scale) {
- params.lora_adapters.push_back({ fname, std::stof(scale), "", "", nullptr });
+ {"--lora-scaled"}, "FNAME:SCALE,...",
+ "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
+ "note: use comma-separated values",
+ [](common_params & params, const std::string & value) {
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ auto parts = string_split<std::string>(item, ':');
+ if (parts.size() != 2) {
+ throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
+ }
+ params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr });
+ }
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--control-vector"}, "FNAME",
- "add a control vector\nnote: this argument can be repeated to add multiple control vectors",
+ "add a control vector\nnote: use comma-separated values to add multiple control vectors",
[](common_params & params, const std::string & value) {
- params.control_vectors.push_back({ 1.0f, value, });
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ params.control_vectors.push_back({ 1.0f, item, });
+ }
}
));
add_opt(common_arg(
- {"--control-vector-scaled"}, "FNAME", "SCALE",
+ {"--control-vector-scaled"}, "FNAME:SCALE,...",
"add a control vector with user defined scaling SCALE\n"
- "note: this argument can be repeated to add multiple scaled control vectors",
- [](common_params & params, const std::string & fname, const std::string & scale) {
- params.control_vectors.push_back({ std::stof(scale), fname });
+ "note: use comma-separated values (format: FNAME:SCALE,...)",
+ [](common_params & params, const std::string & value) {
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ auto parts = string_split<std::string>(item, ':');
+ if (parts.size() != 2) {
+ throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
+ }
+ params.control_vectors.push_back({ std::stof(parts[1]), parts[0] });
+ }
}
));
add_opt(common_arg(
@@ -2327,13 +2423,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_env("HF_TOKEN"));
add_opt(common_arg(
{"--context-file"}, "FNAME",
- "file to load context from (repeat to specify multiple files)",
+ "file to load context from (use comma-separated values to specify multiple files)",
[](common_params & params, const std::string & value) {
- std::ifstream file(value, std::ios::binary);
- if (!file) {
- throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
+ for (const auto & item : string_split<std::string>(value, ',')) {
+ std::ifstream file(item, std::ios::binary);
+ if (!file) {
+ throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
+ }
+ params.context_files.push_back(item);
}
- params.context_files.push_back(value);
}
).set_examples({LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
@@ -2524,6 +2622,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.api_prefix = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX"));
+ add_opt(common_arg(
+ {"--webui-config"}, "JSON",
+ "JSON that provides default WebUI settings (overrides WebUI defaults)",
+ [](common_params & params, const std::string & value) {
+ params.webui_config_json = value;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG"));
+ add_opt(common_arg(
+ {"--webui-config-file"}, "PATH",
+ "JSON file that provides default WebUI settings (overrides WebUI defaults)",
+ [](common_params & params, const std::string & value) {
+ params.webui_config_json = read_file(value);
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE"));
add_opt(common_arg(
{"--webui"},
{"--no-webui"},
@@ -2540,7 +2652,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
- {"--reranking", "--rerank"},
+ {"--rerank", "--reranking"},
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
[](common_params & params) {
params.embedding = true;
@@ -2775,6 +2887,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.lora_init_without_apply = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
+ add_opt(common_arg(
+ {"--sleep-idle-seconds"}, "SECONDS",
+ string_format("number of seconds of idleness after which the server will sleep (default: %d; -1 = disabled)", params.sleep_idle_seconds),
+ [](common_params & params, int value) {
+ if (value == 0 || value < -1) {
+ throw std::invalid_argument("invalid value: cannot be 0 or less than -1");
+ }
+ params.sleep_idle_seconds = value;
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--simple-io"},
"use basic IO for better compatibility in subprocesses and limited consoles",
@@ -3011,7 +3133,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
add_opt(common_arg(
- {"--draft-max", "--draft", "--draft-n"}, "N",
+ {"--draft", "--draft-n", "--draft-max"}, "N",
string_format("number of tokens to draft for speculative decoding (default: %d)", params.speculative.n_max),
[](common_params & params, int value) {
params.speculative.n_max = value;
@@ -3387,3 +3509,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
return ctx_arg;
}
+
+void common_params_add_preset_options(std::vector<common_arg> & args) {
+ // arguments below won't be treated as CLI args, only preset options
+ args.push_back(common_arg(
+ {"load-on-startup"}, "NAME",
+ "in server router mode, autoload this model on startup",
+ [](common_params &, const std::string &) { /* unused */ }
+ ).set_env(COMMON_ARG_PRESET_LOAD_ON_STARTUP).set_preset_only());
+
+ // args.push_back(common_arg(
+ // {"pin"},
+ // "in server router mode, do not unload this model if models_max is exceeded",
+ // [](common_params &) { /* unused */ }
+ // ).set_preset_only());
+
+ // args.push_back(common_arg(
+ // {"unload-idle-seconds"}, "SECONDS",
+ // "in server router mode, unload models idle for more than this many seconds",
+ // [](common_params &, int) { /* unused */ }
+ // ).set_preset_only());
+}
diff --git a/common/arg.h b/common/arg.h
index 1321595c1a..f5111c658f 100644
--- a/common/arg.h
+++ b/common/arg.h
@@ -8,6 +8,9 @@
#include <string>
#include <vector>
+// pseudo-env variable to identify preset-only arguments
+#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
+
//
// CLI argument parsing
//
@@ -22,6 +25,7 @@ struct common_arg {
const char * env = nullptr;
std::string help;
bool is_sparam = false; // is current arg a sampling param?
+ bool is_preset_only = false; // is current arg preset-only (not treated as CLI arg)
void (*handler_void) (common_params & params) = nullptr;
void (*handler_string) (common_params & params, const std::string &) = nullptr;
void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr;
@@ -70,6 +74,7 @@ struct common_arg {
common_arg & set_excludes(std::initializer_list<enum llama_example> excludes);
common_arg & set_env(const char * env);
common_arg & set_sparam();
+ common_arg & set_preset_only();
bool in_example(enum llama_example ex);
bool is_exclude(enum llama_example ex);
bool get_value_from_env(std::string & output) const;
@@ -114,9 +119,13 @@ struct common_params_context {
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
// parse input arguments from CLI into a map
-// TODO: support repeated args in the future
bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<std::string, std::string> & out_map);
+// populate preset-only arguments
+// these arguments are not treated as command line arguments
+// see: https://github.com/ggml-org/llama.cpp/issues/18163
+void common_params_add_preset_options(std::vector<common_arg> & args);
+
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp
index 74a7b6a46d..1bcba9cd86 100644
--- a/common/chat-peg-parser.cpp
+++ b/common/chat-peg-parser.cpp
@@ -4,9 +4,14 @@
using json = nlohmann::json;
-static std::string_view trim_trailing_space(std::string_view sv) {
+static std::string_view trim_trailing_space(std::string_view sv, int max = -1) {
+ int count = 0;
while (!sv.empty() && std::isspace(static_cast<unsigned char>(sv.back()))) {
+ if (max != -1 && count >= max) {
+ break;
+ }
sv.remove_suffix(1);
+ count++;
}
return sv;
}
@@ -93,7 +98,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
if (is_arg_string && current_tool) {
// Serialize to JSON, but exclude the end quote
- std::string dumped = json(node.text).dump();
+ std::string dumped = json(trim_trailing_space(node.text)).dump();
current_tool->arguments += dumped.substr(0, dumped.size() - 1);
needs_closing_quote = true;
}
@@ -101,6 +106,7 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
if (is_arg_close && current_tool) {
if (needs_closing_quote) {
current_tool->arguments += "\"";
+ needs_closing_quote = false;
}
}
@@ -109,6 +115,10 @@ void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) {
}
if (is_tool_close && current_tool) {
+ if (needs_closing_quote) {
+ current_tool->arguments += "\"";
+ needs_closing_quote = false;
+ }
current_tool->arguments += "}";
}
}
diff --git a/common/chat.cpp b/common/chat.cpp
index c371edaa5a..0a426f4478 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -711,6 +711,25 @@ static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
+static void foreach_parameter(const json & function, const std::function<void(const std::string &, const json &, bool)> & fn) {
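+ // iterate over the tool function's parameter schemas, calling fn(name, schema, is_required) for each property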
+ if (!function.contains("parameters") || !function.at("parameters").is_object()) {
+ return;
+ }
+ const auto & params = function.at("parameters");
+ if (!params.contains("properties") || !params.at("properties").is_object()) {
+ return;
+ }
+ const auto & props = params.at("properties");
+ std::set<std::string> required;
+ if (params.contains("required") && params.at("required").is_array()) {
+ params.at("required").get_to(required);
+ }
+ for (const auto & [name, prop] : props.items()) {
+ bool is_required = (required.find(name) != required.end());
+ fn(name, prop, is_required);
+ }
+}
+
static std::string apply(
const common_chat_template & tmpl,
const struct templates_params & inputs,
@@ -1409,6 +1428,123 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
return data;
}
+static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED;
+
+ // Handle thinking tags appropriately based on inputs.enable_thinking
+ if (string_ends_with(data.prompt, "<think>\n")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ data.preserved_tokens = {
+ "",
+ "",
+ "",
+ "",
+ };
+
+ auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
+ auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+ auto include_grammar = true;
+
+ auto parser = build_chat_peg_constructed_parser([&](auto & p) {
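+ // overall shape: optional reasoning block, then either schema-constrained JSON content, tool calls, or plain content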
+ auto reasoning = p.eps();
+ if (inputs.enable_thinking && extract_reasoning) {
+ auto reasoning_content = p.reasoning(p.until("</think>")) + ("</think>" | p.end());
+ if (data.thinking_forced_open) {
+ reasoning = reasoning_content;
+ }
+ }
+
+ // Response format parser
+ if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
+ return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema));
+ }
+
+ // Tool call parser
+ if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
+ auto tool_choice = p.choice();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ std::string name = function.at("name");
+ auto parameters = function.at("parameters");
+
+ auto schema_info = common_schema_info();
+ schema_info.resolve_refs(parameters);
+
+ auto tool_open = "\n";
+ auto tool_close = p.literal("\n");
+ auto args = p.sequence();
+ auto arg_string = p.rule("xml-arg-string", p.until_one_of({
+ "\n",
+ "\n"
+ }));
+
+ foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) {
+ auto rule_name = "tool-" + name + "-arg-" + param_name;
+
+ auto arg_open = "\n";
+ auto arg_close = p.literal("\n");
+ auto arg_value = p.eps();
+
+ if (schema_info.resolves_to_string(param_schema)) {
+ arg_value = p.tool_arg_string_value(arg_string) + "\n";
+ } else {
+ arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema));
+ }
+
+ // Model may or may not emit the closing tag
+ auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close)));
+ args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1);
+ });
+
+ tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close));
+ });
+
+ auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
+ auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
+ auto tool_call = p.rule("tool-call", "\n" + tool_choice + "" + p.space());
+ auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls));
+
+ return reasoning << p.content(p.until("")) << tool_calls;
+ }
+
+ // Content only parser
+ include_grammar = false;
+ return reasoning << p.content(p.rest());
+ });
+
+ data.parser = parser.save();
+
+ if (include_grammar) {
+ data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
+
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ auto schema = function.at("parameters");
+ builder.resolve_refs(schema);
+ });
+ parser.build_grammar(builder, data.grammar_lazy);
+ });
+
+ data.grammar_triggers = {
+ {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""}
+ };
+ }
+
+ return data;
+}
+
+
static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -2534,6 +2670,10 @@ static common_chat_params common_chat_templates_apply_jinja(
src.find("") != std::string::npos &&
src.find("") != std::string::npos) {
+ return common_chat_params_init_nemotron_v3(tmpl, params);
+ }
return common_chat_params_init_qwen3_coder_xml(tmpl, params);
}
diff --git a/common/common.cpp b/common/common.cpp
index 5a8cf52485..acf2ec841d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1078,6 +1078,8 @@ struct common_init_result::impl {
impl() = default;
~impl() = default;
+ // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top
+
llama_model_ptr model;
llama_context_ptr context;
@@ -1092,7 +1094,7 @@ common_init_result::common_init_result(common_params & params) :
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
- LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__);
+ LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx,
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
diff --git a/common/common.h b/common/common.h
index d70744840f..334372073a 100644
--- a/common/common.h
+++ b/common/common.h
@@ -475,7 +475,8 @@ struct common_params {
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
int reasoning_budget = -1;
- bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
+ bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
+ int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time
std::vector<std::string> api_keys;
@@ -484,8 +485,11 @@ struct common_params {
std::map<std::string, std::string> default_template_kwargs;
+ // webui configs
+ bool webui = true;
+ std::string webui_config_json;
+
// "advanced" endpoints are disabled by default for better security
- bool webui = true;
bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index c3b4e5d9dc..2f67c74d79 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -305,8 +305,9 @@ static std::string format_literal(const std::string & literal) {
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
-class SchemaConverter {
+class common_schema_converter {
private:
+ friend class common_schema_info;
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
@@ -729,7 +730,7 @@ private:
}
public:
- SchemaConverter(
+ common_schema_converter(
const std::function<json(const std::string &)> & fetch_json,
bool dotall)
: _fetch_json(fetch_json), _dotall(dotall)
@@ -990,6 +991,134 @@ public:
}
};
+// common_schema_info implementation (pimpl)
+
+common_schema_info::common_schema_info()
+ : impl_(std::make_unique<common_schema_converter>(
+ [](const std::string &) { return json(); },
+ false)) {}
+
+common_schema_info::~common_schema_info() = default;
+
+common_schema_info::common_schema_info(common_schema_info &&) noexcept = default;
+common_schema_info & common_schema_info::operator=(common_schema_info &&) noexcept = default;
+
+void common_schema_info::resolve_refs(nlohmann::ordered_json & schema) {
+ impl_->resolve_refs(schema, "");
+}
+
+// Determines if a JSON schema can resolve to a string type through any path.
+// Some models emit raw string values rather than JSON-encoded strings for string parameters.
+// If any branch of the schema (via oneOf, anyOf, $ref, etc.) permits a string, this returns
+// true, allowing callers to handle the value as a raw string for simplicity.
+bool common_schema_info::resolves_to_string(const nlohmann::ordered_json & schema) {
+ std::unordered_set<std::string> visited_refs;
+
+ std::function<bool(const json &)> check = [&](const json & s) -> bool {
+ if (!s.is_object()) {
+ return false;
+ }
+
+ // Handle $ref
+ if (s.contains("$ref")) {
+ const std::string & ref = s["$ref"];
+ if (visited_refs.find(ref) != visited_refs.end()) {
+ // Circular reference, assume not a string to be safe
+ return false;
+ }
+ visited_refs.insert(ref);
+ auto it = impl_->_refs.find(ref);
+ if (it != impl_->_refs.end()) {
+ return check(it->second);
+ }
+ return false;
+ }
+
+ // Check type field
+ if (s.contains("type")) {
+ const json & schema_type = s["type"];
+ if (schema_type.is_string()) {
+ if (schema_type == "string") {
+ return true;
+ }
+ } else if (schema_type.is_array()) {
+ // Type can be an array like ["string", "null"]
+ for (const auto & t : schema_type) {
+ if (t == "string") {
+ return true;
+ }
+ }
+ }
+ }
+
+ // Check oneOf/anyOf - if any alternative can be a string
+ if (s.contains("oneOf")) {
+ for (const auto & alt : s["oneOf"]) {
+ if (check(alt)) {
+ return true;
+ }
+ }
+ }
+ if (s.contains("anyOf")) {
+ for (const auto & alt : s["anyOf"]) {
+ if (check(alt)) {
+ return true;
+ }
+ }
+ }
+
+ // Check allOf - all components must be compatible with string type
+ if (s.contains("allOf")) {
+ bool all_string = true;
+ for (const auto & component : s["allOf"]) {
+ if (!check(component)) {
+ all_string = false;
+ break;
+ }
+ }
+ if (all_string) {
+ return true;
+ }
+ }
+
+ // Check const - if the constant value is a string
+ if (s.contains("const")) {
+ if (s["const"].is_string()) {
+ return true;
+ }
+ }
+
+ // Check enum - if any enum value is a string
+ if (s.contains("enum")) {
+ for (const auto & val : s["enum"]) {
+ if (val.is_string()) {
+ return true;
+ }
+ }
+ }
+
+ // String-specific keywords imply string type
+ if (s.contains("pattern") || s.contains("minLength") || s.contains("maxLength")) {
+ return true;
+ }
+
+ // Check format - many formats imply string
+ if (s.contains("format")) {
+ const std::string & fmt = s["format"];
+ if (fmt == "date" || fmt == "time" || fmt == "date-time" ||
+ fmt == "uri" || fmt == "email" || fmt == "hostname" ||
+ fmt == "ipv4" || fmt == "ipv6" || fmt == "uuid" ||
+ fmt.find("uuid") == 0) {
+ return true;
+ }
+ }
+
+ return false;
+ };
+
+ return check(schema);
+}
+
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
@@ -1006,7 +1135,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
}
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
- SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall);
+ common_schema_converter converter([&](const std::string &) { return json(); }, options.dotall);
common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h
index c89ab7f997..240d642311 100644
--- a/common/json-schema-to-grammar.h
+++ b/common/json-schema-to-grammar.h
@@ -3,11 +3,31 @@
#include
#include
+#include <memory>
#include
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);
+class common_schema_converter;
+
+// Probes a JSON schema to extract information about its structure and type constraints.
+class common_schema_info {
+ std::unique_ptr<common_schema_converter> impl_;
+
+ public:
+ common_schema_info();
+ ~common_schema_info();
+
+ common_schema_info(const common_schema_info &) = delete;
+ common_schema_info & operator=(const common_schema_info &) = delete;
+ common_schema_info(common_schema_info &&) noexcept;
+ common_schema_info & operator=(common_schema_info &&) noexcept;
+
+ void resolve_refs(nlohmann::ordered_json & schema);
+ bool resolves_to_string(const nlohmann::ordered_json & schema);
+};
+
struct common_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp
index dec99e1820..f2fc84500f 100644
--- a/common/peg-parser.cpp
+++ b/common/peg-parser.cpp
@@ -425,7 +425,7 @@ struct parser_executor {
if (result.need_more_input()) {
// Propagate - need to know what child would match before negating
- return result;
+ return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos);
}
// Child failed, so negation succeeds
diff --git a/common/preset.cpp b/common/preset.cpp
index 60746aad58..e2fc18c5da 100644
--- a/common/preset.cpp
+++ b/common/preset.cpp
@@ -2,6 +2,7 @@
#include "preset.h"
#include "peg-parser.h"
#include "log.h"
+#include "download.h"
#include
#include
@@ -15,11 +16,22 @@ static std::string rm_leading_dashes(const std::string & str) {
return str.substr(pos);
}
-std::vector<std::string> common_preset::to_args() const {
+std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
std::vector<std::string> args;
+ if (!bin_path.empty()) {
+ args.push_back(bin_path);
+ }
+
for (const auto & [opt, value] : options) {
- args.push_back(opt.args.back()); // use the last arg as the main arg
+ if (opt.is_preset_only) {
+ continue; // skip preset-only options (they are not CLI args)
+ }
+
+ // use the last arg as the main arg (i.e. --long-form)
+ args.push_back(opt.args.back());
+
+ // handle value(s)
if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) {
// flag option, no value
if (common_arg_utils::is_falsey(value)) {
@@ -63,6 +75,52 @@ std::string common_preset::to_ini() const {
return ss.str();
}
+void common_preset::set_option(const common_preset_context & ctx, const std::string & env, const std::string & value) {
+ // try if option exists, update it
+ for (auto & [opt, val] : options) {
+ if (opt.env && env == opt.env) {
+ val = value;
+ return;
+ }
+ }
+ // if option does not exist, we need to add it
+ if (ctx.key_to_opt.find(env) == ctx.key_to_opt.end()) {
+ throw std::runtime_error(string_format(
+ "%s: option with env '%s' not found in ctx_params",
+ __func__, env.c_str()
+ ));
+ }
+ options[ctx.key_to_opt.at(env)] = value;
+}
+
+void common_preset::unset_option(const std::string & env) {
+ for (auto it = options.begin(); it != options.end(); ) {
+ const common_arg & opt = it->first;
+ if (opt.env && env == opt.env) {
+ it = options.erase(it);
+ return;
+ } else {
+ ++it;
+ }
+ }
+}
+
+bool common_preset::get_option(const std::string & env, std::string & value) const {
+ for (const auto & [opt, val] : options) {
+ if (opt.env && env == opt.env) {
+ value = val;
+ return true;
+ }
+ }
+ return false;
+}
+
+void common_preset::merge(const common_preset & other) {
+ for (const auto & [opt, val] : other.options) {
+ options[opt] = val; // overwrite existing options
+ }
+}
+
static std::map<std::string, std::map<std::string, std::string>> parse_ini_from_file(const std::string & path) {
std::map<std::string, std::map<std::string, std::string>> parsed;
@@ -172,9 +230,14 @@ static std::string parse_bool_arg(const common_arg & arg, const std::string & ke
return value;
}
-common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) {
+common_preset_context::common_preset_context(llama_example ex)
+ : ctx_params(common_params_parser_init(default_params, ex)) {
+ common_params_add_preset_options(ctx_params.options);
+ key_to_opt = get_map_key_opt(ctx_params);
+}
+
+common_presets common_preset_context::load_from_ini(const std::string & path, common_preset & global) const {
common_presets out;
- auto key_to_opt = get_map_key_opt(ctx_params);
auto ini_data = parse_ini_from_file(path);
for (auto section : ini_data) {
@@ -188,7 +251,7 @@ common_presets common_presets_load(const std::string & path, common_params_conte
for (const auto & [key, value] : section.second) {
LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str());
if (key_to_opt.find(key) != key_to_opt.end()) {
- auto & opt = key_to_opt[key];
+ const auto & opt = key_to_opt.at(key);
if (is_bool_arg(opt)) {
preset.options[opt] = parse_bool_arg(opt, key, value);
} else {
@@ -199,8 +262,137 @@ common_presets common_presets_load(const std::string & path, common_params_conte
// TODO: maybe warn about unknown key?
}
}
+
+ if (preset.name == "*") {
+ // handle global preset
+ global = preset;
+ } else {
+ out[preset.name] = preset;
+ }
+ }
+
+ return out;
+}
+
+common_presets common_preset_context::load_from_cache() const {
+ common_presets out;
+
+ auto cached_models = common_list_cached_models();
+ for (const auto & model : cached_models) {
+ common_preset preset;
+ preset.name = model.to_string();
+ preset.set_option(*this, "LLAMA_ARG_HF_REPO", model.to_string());
out[preset.name] = preset;
}
return out;
}
+
+struct local_model {
+ std::string name;
+ std::string path;
+ std::string path_mmproj;
+};
+
+common_presets common_preset_context::load_from_models_dir(const std::string & models_dir) const {
+ if (!std::filesystem::exists(models_dir) || !std::filesystem::is_directory(models_dir)) {
+ throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", models_dir.c_str()));
+ }
+
+ std::vector<local_model> models;
+ auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
+ auto files = fs_list(subdir_path, false);
+ common_file_info model_file;
+ common_file_info first_shard_file;
+ common_file_info mmproj_file;
+ for (const auto & file : files) {
+ if (string_ends_with(file.name, ".gguf")) {
+ if (file.name.find("mmproj") != std::string::npos) {
+ mmproj_file = file;
+ } else if (file.name.find("-00001-of-") != std::string::npos) {
+ first_shard_file = file;
+ } else {
+ model_file = file;
+ }
+ }
+ }
+ // prefer the first shard if the model is sharded, otherwise use the single-file model
+ local_model model{
+ /* name */ name,
+ /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
+ /* path_mmproj */ mmproj_file.path // can be empty
+ };
+ if (!model.path.empty()) {
+ models.push_back(model);
+ }
+ };
+
+ auto files = fs_list(models_dir, true);
+ for (const auto & file : files) {
+ if (file.is_dir) {
+ scan_subdir(file.path, file.name);
+ } else if (string_ends_with(file.name, ".gguf")) {
+ // single file model
+ std::string name = file.name;
+ string_replace_all(name, ".gguf", "");
+ local_model model{
+ /* name */ name,
+ /* path */ file.path,
+ /* path_mmproj */ ""
+ };
+ models.push_back(model);
+ }
+ }
+
+ // convert local models to presets
+ common_presets out;
+ for (const auto & model : models) {
+ common_preset preset;
+ preset.name = model.name;
+ preset.set_option(*this, "LLAMA_ARG_MODEL", model.path);
+ if (!model.path_mmproj.empty()) {
+ preset.set_option(*this, "LLAMA_ARG_MMPROJ", model.path_mmproj);
+ }
+ out[preset.name] = preset;
+ }
+
+ return out;
+}
+
+common_preset common_preset_context::load_from_args(int argc, char ** argv) const {
+ common_preset preset;
+ preset.name = COMMON_PRESET_DEFAULT_NAME;
+
+ bool ok = common_params_to_map(argc, argv, ctx_params.ex, preset.options);
+ if (!ok) {
+ throw std::runtime_error("failed to parse CLI arguments into preset");
+ }
+
+ return preset;
+}
+
+common_presets common_preset_context::cascade(const common_presets & base, const common_presets & added) const {
+ common_presets out = base; // copy
+ for (const auto & [name, preset_added] : added) {
+ if (out.find(name) != out.end()) {
+ // if exists, merge
+ common_preset & target = out[name];
+ target.merge(preset_added);
+ } else {
+ // otherwise, add directly
+ out[name] = preset_added;
+ }
+ }
+ return out;
+}
+
+common_presets common_preset_context::cascade(const common_preset & base, const common_presets & presets) const {
+ common_presets out;
+ for (const auto & [name, preset] : presets) {
+ common_preset tmp = base; // copy
+ tmp.name = name;
+ tmp.merge(preset);
+ out[name] = std::move(tmp);
+ }
+ return out;
+}
diff --git a/common/preset.h b/common/preset.h
index dceb849eb8..3a84d1be29 100644
--- a/common/preset.h
+++ b/common/preset.h
@@ -13,20 +13,62 @@
constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default";
+struct common_preset_context;
+
struct common_preset {
std::string name;
- // TODO: support repeated args in the future
+
+ // options are stored as common_arg to string mapping, representing CLI arg and its value
std::map<common_arg, std::string> options;
// convert preset to CLI argument list
- std::vector<std::string> to_args() const;
+ std::vector<std::string> to_args(const std::string & bin_path = "") const;
// convert preset to INI format string
std::string to_ini() const;
// TODO: maybe implement to_env() if needed
+
+ // modify preset options where argument is identified by its env variable
+ void set_option(const common_preset_context & ctx, const std::string & env, const std::string & value);
+
+ // unset option by its env variable
+ void unset_option(const std::string & env);
+
+ // get option value by its env variable, return false if not found
+ bool get_option(const std::string & env, std::string & value) const;
+
+ // merge another preset into this one, overwriting existing options
+ void merge(const common_preset & other);
};
// interface for multiple presets in one file
using common_presets = std::map<std::string, common_preset>;
-common_presets common_presets_load(const std::string & path, common_params_context & ctx_params);
+
+// context for loading and editing presets
+struct common_preset_context {
+ common_params default_params; // unused for now
+ common_params_context ctx_params;
+ std::map<std::string, common_arg> key_to_opt;
+ common_preset_context(llama_example ex);
+
+ // load presets from INI file
+ common_presets load_from_ini(const std::string & path, common_preset & global) const;
+
+ // generate presets from cached models
+ common_presets load_from_cache() const;
+
+ // generate presets from local models directory
+ // for the directory structure, see "Using multiple models" in server/README.md
+ common_presets load_from_models_dir(const std::string & models_dir) const;
+
+ // generate one preset from CLI arguments
+ common_preset load_from_args(int argc, char ** argv) const;
+
+ // cascade multiple presets; when a preset exists in both, entries from added override base (base < added)
+ // if preset does not exist in base, it will be added without modification
+ common_presets cascade(const common_presets & base, const common_presets & added) const;
+
+ // apply presets over a base preset (same idea as CSS cascading)
+ common_presets cascade(const common_preset & base, const common_presets & presets) const;
+};
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 6935d84e22..c66f935c65 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -104,10 +104,9 @@ struct ring_buffer {
struct common_sampler {
common_params_sampling params;
+ struct llama_sampler * grmr;
struct llama_sampler * chain;
- bool grammar;
-
ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur;
@@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf;
+ llama_sampler * grmr = nullptr;
llama_sampler * chain = llama_sampler_chain_init(lparams);
- bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
- samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
- grammar = true;
+ grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
#else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
if (!params.grammar.empty()) {
if (params.grammar_lazy) {
- samplers.push_back(
- llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
- trigger_patterns_c.data(), trigger_patterns_c.size(),
- trigger_tokens.data(), trigger_tokens.size()));
+ grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+ trigger_patterns_c.data(), trigger_patterns_c.size(),
+ trigger_tokens.data(), trigger_tokens.size());
} else {
- samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+ grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
}
-
- grammar = true;
}
}
@@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
auto * result = new common_sampler {
/* .params = */ params,
+ /* .grmr = */ grmr,
/* .chain = */ chain,
- /* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
@@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
+ llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
delete gsmpl;
@@ -324,25 +320,12 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
- if (gsmpl->grammar) {
- const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
-
- for (int i = 0; i < n_smpl; i++) {
- auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-
- // the grammar sampler is always the first one
- if (i == 0) {
- if (accept_grammar) {
- llama_sampler_accept(smpl, token);
- }
- } else {
- llama_sampler_accept(smpl, token);
- }
- }
- } else {
- llama_sampler_accept(gsmpl->chain, token);
+ if (gsmpl->grmr && accept_grammar) {
+ llama_sampler_accept(gsmpl->grmr, token);
}
+ llama_sampler_accept(gsmpl->chain, token);
+
gsmpl->prev.push_back(token);
}
@@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
+ /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
- /* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
@@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token id = LLAMA_TOKEN_NULL;
+ auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
gsmpl->set_logits(ctx, idx);
+ if (grammar_first) {
+ llama_sampler_apply(grmr, &cur_p);
+ }
+
+ llama_sampler_apply(chain, &cur_p);
+
+ id = cur_p.data[cur_p.selected].id;
+
+ if (grammar_first) {
+ return id;
+ }
+
+ // check if the sampled token fits the grammar (grammar-based rejection sampling)
+ {
+ llama_token_data single_token_data = { id, 1.0f, 0.0f };
+ llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+ llama_sampler_apply(grmr, &single_token_data_array);
+
+ const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+ if (is_valid) {
+ return id;
+ }
+ }
+
+ // resampling:
+ // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+ gsmpl->set_logits(ctx, idx);
+
+ llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return id;
}
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
@@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0;
for (; i < draft.size(); i++) {
- const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+ const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
@@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}
if (i == draft.size()) {
- const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+ const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
@@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
- return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+ return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
}
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
diff --git a/common/sampling.h b/common/sampling.h
index ace5d3d020..c7101032f2 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample
//
@@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 1e12383ae6..3e83b0964c 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
- common_sampler_sample(smpl, ctx_dft, 0);
+ common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ee02cdd91c..22f703e6ad 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -136,29 +136,29 @@ class ModelBase:
self.remote_hf_model_id = remote_hf_model_id
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
- self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
- # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
- if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
- if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
- self.rope_parameters["rope_theta"] = rope_theta
- if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
- self.rope_parameters["rope_type"] = rope_type
-
- # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
+ # Apply heuristics to figure out typical tensor encoding based on the dtype of the first multi-dimensional tensor
+ # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
if self.ftype == gguf.LlamaFileType.GUESSED:
- # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
- _, first_tensor = next(self.get_tensors())
- if first_tensor.dtype == torch.float16:
- logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})")
- self.ftype = gguf.LlamaFileType.MOSTLY_F16
+ for _, tensor in self.get_tensors():
+ if tensor.dim() < 2:
+ continue
+
+ if tensor.dtype == torch.bfloat16:
+ self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+ logger.info("heuristics detected bfloat16 tensor dtype, setting --outtype bf16")
+ break
+ elif tensor.dtype == torch.float16:
+ self.ftype = gguf.LlamaFileType.MOSTLY_F16
+ logger.info("heuristics detected float16 tensor dtype, setting --outtype f16")
+ break
else:
- logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
- self.ftype = gguf.LlamaFileType.MOSTLY_BF16
+ self.ftype = gguf.LlamaFileType.MOSTLY_F16
+ logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
self.dequant_model()
@@ -197,10 +197,10 @@ class ModelBase:
return tensors
prefix = "model" if not self.is_mistral_format else "consolidated"
- part_names: set[str] = set(ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors"))
+ part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
is_safetensors: bool = len(part_names) > 0
if not is_safetensors:
- part_names = set(ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin"))
+ part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
tensor_names_from_index: set[str] = set()
@@ -217,7 +217,8 @@ class ModelBase:
if weight_map is None or not isinstance(weight_map, dict):
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
tensor_names_from_index.update(weight_map.keys())
- part_names |= set(weight_map.values())
+ part_dict: dict[str, None] = dict.fromkeys(weight_map.values(), None)
+ part_names = sorted(part_dict.keys())
else:
weight_map = {}
else:
@@ -719,6 +720,9 @@ class ModelBase:
if "thinker_config" in config:
# rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"]
+ if "lfm" in config:
+ # rename for LFM2-Audio
+ config["text_config"] = config["lfm"]
return config
@classmethod
@@ -765,6 +769,15 @@ class TextModel(ModelBase):
self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+ self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
+
+ # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
+ if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
+ if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
+ self.rope_parameters["rope_theta"] = rope_theta
+ if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
+ self.rope_parameters["rope_type"] = rope_type
+
@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
@@ -861,6 +874,14 @@ class TextModel(ModelBase):
logger.warning(f"Unknown RoPE type: {rope_type}")
logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")
+ if "mrope_section" in self.rope_parameters:
+ mrope_section = self.rope_parameters["mrope_section"]
+ # Pad to 4 dimensions [time, height, width, extra]
+ while len(mrope_section) < 4:
+ mrope_section.append(0)
+ self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+ logger.info(f"gguf: mrope sections: {mrope_section[:4]}")
+
if (rope_theta := rope_params.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}")
@@ -1203,6 +1224,9 @@ class TextModel(ModelBase):
if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95":
# ref: https://huggingface.co/MiniMaxAI/MiniMax-M2
res = "minimax-m2"
+ if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
+ # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
+ res = "kormo"
if res is None:
logger.warning("\n")
@@ -1826,7 +1850,7 @@ class MmprojModel(ModelBase):
def tensor_force_quant(self, name, new_name, bid, n_dims):
del bid, name, n_dims # unused
- if ".patch_embd.weight" in new_name:
+ if ".patch_embd.weight" in new_name or ".patch_merger.weight" in new_name:
return gguf.GGMLQuantizationType.F16 if self.ftype == gguf.LlamaFileType.MOSTLY_F16 else gguf.GGMLQuantizationType.F32
return False
@@ -3398,7 +3422,7 @@ class QwenModel(TextModel):
self._set_vocab_qwen()
-@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration")
+@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM")
class Qwen2Model(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2
@@ -3735,9 +3759,6 @@ class Qwen2VLModel(TextModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
- mrope_section = self.hparams["rope_scaling"]["mrope_section"]
- mrope_section += [0] * max(0, 4 - len(mrope_section))
- self.gguf_writer.add_rope_dimension_sections(mrope_section)
def set_vocab(self):
try:
@@ -4373,6 +4394,30 @@ class Qwen3VLVisionModel(MmprojModel):
return super().modify_tensors(data_torch, name, bid)
+@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
+class Glm4VVisionModel(Qwen3VLVisionModel):
+ def set_gguf_parameters(self):
+ MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters
+ assert self.hparams_vision is not None
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+
+ hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+ if hidden_act == "gelu":
+ self.gguf_writer.add_vision_use_gelu(True)
+ elif hidden_act == "silu":
+ self.gguf_writer.add_vision_use_silu(True)
+
+ rms_norm_eps = self.hparams_vision.get("rms_norm_eps", 1e-5)
+ self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ if name.startswith("model.visual."):
+ name = name.replace("model.visual.", "visual.")
+ if name.startswith("visual.merger."):
+ return [(self.map_tensor_name(name), data_torch)]
+ return super().modify_tensors(data_torch, name, bid)
+
+
@ModelBase.register("Qwen3VLForConditionalGeneration")
class Qwen3VLTextModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.QWEN3VL
@@ -4381,20 +4426,6 @@ class Qwen3VLTextModel(Qwen3Model):
super().set_gguf_parameters()
# Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
- text_config = self.hparams.get("text_config", {})
- # rope_scaling is deprecated in V5, use rope_parameters instead
- rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
- if rope_scaling.get("mrope_section"):
- # mrope_section contains [time, height, width] dimensions
- mrope_section = rope_scaling["mrope_section"]
- # Pad to 4 dimensions [time, height, width, extra]
- while len(mrope_section) < 4:
- mrope_section.append(0)
- self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
- logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -4413,22 +4444,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
-
- # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
- text_config = self.hparams.get("text_config", {})
- # rope_scaling is deprecated in V5, use rope_parameters instead
- rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
- if rope_scaling.get("mrope_section"):
- # mrope_section contains [time, height, width] dimensions
- mrope_section = rope_scaling["mrope_section"]
- # Pad to 4 dimensions [time, height, width, extra]
- while len(mrope_section) < 4:
- mrope_section.append(0)
- self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
- logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
vision_config = self.hparams.get("vision_config", {})
deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -7791,6 +7806,15 @@ class JaisModel(TextModel):
@ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
class Glm4Model(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4
+ use_mrope = False
+ partial_rotary_factor = 0.5
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 0.5)
+ if "mrope_section" in self.rope_parameters:
+ self.use_mrope = True
+ logger.info("Q/K weight will need to be permuted for M-RoPE")
def set_vocab(self):
from transformers import AutoTokenizer
@@ -7812,17 +7836,49 @@ class Glm4Model(TextModel):
super().set_gguf_parameters()
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
- self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+ self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.partial_rotary_factor))
+
+ @staticmethod
+ def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, partial_rotary_factor: float) -> Tensor:
+ orig_shape = weights.shape
+ if len(orig_shape) == 1:
+ weights = weights.unsqueeze(1) # [out_dim, 1]
+ if len(weights.shape) != 2:
+ raise ValueError("Only 1D and 2D tensors are supported.")
+ n_effective_heads = weights.shape[0] // head_dim
+ if n_head_kv is not None and n_effective_heads != n_head:
+ if n_effective_heads != n_head_kv:
+ raise AssertionError(f"Mismatch in effective heads: computed {n_effective_heads}, expected {n_head} or {n_head_kv}")
+ rotary_dim = int(head_dim * partial_rotary_factor)
+ if rotary_dim % 2 != 0:
+ raise ValueError("rotary_dim must be even.")
+ reshaped = weights.reshape(n_effective_heads, head_dim, -1)
+ rot_part = reshaped[:, :rotary_dim, :]
+ non_rot_part = reshaped[:, rotary_dim:, :]
+ permuted_rot = torch.cat((rot_part[:, ::2, :], rot_part[:, 1::2, :]), dim=1)
+ combined = torch.cat((permuted_rot, non_rot_part), dim=1)
+ result = combined.reshape(weights.shape)
+ return result if len(orig_shape) != 1 else result.squeeze(1)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."): # ignore visual part of Glm4v
return []
elif name.startswith("model.language_model."):
name = name.replace("language_model.", "") # for Glm4v
+ if self.use_mrope:
+ n_head = self.hparams["num_attention_heads"]
+ n_kv_head = self.hparams["num_key_value_heads"]
+ n_embd = self.hparams["hidden_size"]
+ head_dim = n_embd // n_head
+ # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor)
return super().modify_tensors(data_torch, name, bid)
-@ModelBase.register("Glm4MoeForCausalLM")
+@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
class Glm4MoeModel(TextModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
@@ -7889,6 +7945,7 @@ class Glm4MoeModel(TextModel):
_experts: list[dict[str, Tensor]] | None = None
+ # note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already
def modify_tensors(
self, data_torch: Tensor, name: str, bid: int | None
) -> Iterable[tuple[str, Tensor]]:
@@ -8486,8 +8543,18 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
class NemotronHModel(GraniteHybridModel):
"""Hybrid mamba2/attention model from NVIDIA"""
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+ is_moe: bool = False
def __init__(self, *args, **kwargs):
+ # We have to determine the correct model architecture (MoE vs non-MoE) before
+ # calling the parent __init__. This is because the parent constructor
+ # uses self.model_arch to build the tensor name map, and all MoE-specific
+ # mappings would be missed if it were called with the default non-MoE arch.
+ hparams = ModelBase.load_hparams(args[0], self.is_mistral_format)
+ if "num_experts_per_tok" in hparams:
+ self.model_arch = gguf.MODEL_ARCH.NEMOTRON_H_MOE
+ self.is_moe = True
+
super().__init__(*args, **kwargs)
# Save the top-level head_dim for later
@@ -8499,9 +8566,11 @@ class NemotronHModel(GraniteHybridModel):
# Update the ssm / attn / mlp layers
# M: Mamba2, *: Attention, -: MLP
+ # MoE:
+ # M: Mamba2, *: Attention, E: Expert
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
- self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+ self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == ("E" if self.is_moe else "-")]
def get_attn_layers(self):
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
@@ -8517,10 +8586,28 @@ class NemotronHModel(GraniteHybridModel):
# Set feed_forward_length
# NOTE: This will trigger an override warning. This is preferable to
# duplicating all the parent logic
- n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
- self.gguf_writer.add_feed_forward_length([
- n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
- ])
+ if not self.is_moe:
+ n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+ self.gguf_writer.add_feed_forward_length([
+ n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+ else:
+ moe_intermediate_size = self.hparams["moe_intermediate_size"]
+ self.gguf_writer.add_feed_forward_length([
+ moe_intermediate_size if i in self._mlp_layers else 0 for i in range(self.block_count)
+ ])
+ self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+ self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+ self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["moe_shared_expert_intermediate_size"])
+ self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
+ self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
+ self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+ self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+ self.gguf_writer.add_expert_group_count(self.hparams["n_group"])
+
+ # number of experts used per token (top-k)
+ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+ self.gguf_writer.add_expert_used_count(n_experts_used)
def set_vocab(self):
super().set_vocab()
@@ -8528,7 +8615,81 @@ class NemotronHModel(GraniteHybridModel):
# The tokenizer _does_ add a BOS token (via post_processor type
# TemplateProcessing) but does not set add_bos_token to true in the
# config, so we need to explicitly override it here.
- self.gguf_writer.add_add_bos_token(True)
+ if not self.is_moe:
+ self.gguf_writer.add_add_bos_token(True)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ if self.is_moe and bid is not None:
+ if name.endswith("mixer.gate.e_score_correction_bias"):
+ new_name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+ mapped_name = self.map_tensor_name(new_name)
+ return [(mapped_name, data_torch)]
+
+ if name.endswith("mixer.dt_bias"):
+ new_name = name.replace("dt_bias", "dt.bias")
+ mapped_name = self.map_tensor_name(new_name)
+ return [(mapped_name, data_torch)]
+
+ if name.endswith("mixer.conv1d.weight"):
+ squeezed_data = data_torch.squeeze()
+ mapped_name = self.map_tensor_name(name)
+ return [(mapped_name, squeezed_data)]
+
+ if name.endswith("mixer.A_log"):
+ transformed_data = -torch.exp(data_torch)
+ reshaped_data = transformed_data.squeeze().reshape(-1, 1)
+ mapped_name = self.map_tensor_name(name)
+ return [(mapped_name, reshaped_data)]
+
+ if name.endswith("mixer.D"):
+ reshaped_data = data_torch.squeeze().reshape(-1, 1)
+ mapped_name = self.map_tensor_name(name)
+ return [(mapped_name, reshaped_data)]
+
+ if name.endswith("mixer.norm.weight"):
+ reshaped_data = data_torch.reshape(8, 512)
+ mapped_name = self.map_tensor_name(name)
+ return [(mapped_name, reshaped_data)]
+
+ if name.find("mixer.experts") != -1:
+ n_experts = self.hparams["n_routed_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 2:
+ # merge the experts into a single tensor
+ tensors: list[tuple[str, Tensor]] = []
+ for w_name in ["down_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"backbone.layers.{bid}.mixer.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+ new_name = self.map_tensor_name(merged_name)
+ tensors.append((new_name, data_torch))
+
+ return tensors
+ else:
+ return []
+
+ return super().modify_tensors(data_torch, name, bid)
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("BailingMoeForCausalLM")
@@ -9563,12 +9724,12 @@ class LFM2Model(TextModel):
self._add_feed_forward_length()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
- is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
- if is_vision_tensor:
- # skip vision tensors
+ if self._is_vision_tensor(name) or self._is_audio_tensor(name):
+ # skip multimodal tensors
return []
- name = name.replace("language_model.", "")
+ name = name.replace("language_model.", "") # vision
+ name = name.replace("lfm.", "model.") # audio
# conv op requires 2d tensor
if 'conv.conv' in name:
@@ -9576,6 +9737,12 @@ class LFM2Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
+ def _is_vision_tensor(self, name: str) -> bool:
+ return "vision_tower" in name or "multi_modal_projector" in name
+
+ def _is_audio_tensor(self, name: str):
+ return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])
+
@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
@@ -9681,6 +9848,81 @@ class LFM2VLModel(MmprojModel):
return [] # skip other tensors
+@ModelBase.register("Lfm2AudioForConditionalGeneration")
+class LFM2AudioModel(MmprojModel):
+ has_vision_encoder = False
+ has_audio_encoder = True
+ model_name = "Lfm2AudioEncoder"
+
+ _batch_norm_tensors: list[dict[str, Tensor]] | None = None
+
+ def get_audio_config(self) -> dict[str, Any] | None:
+ return self.global_config.get("encoder")
+
+ def set_gguf_parameters(self):
+ assert self.hparams_audio is not None
+ self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]
+ self.hparams_audio["intermediate_size"] = self.hparams_audio["d_model"]
+ self.hparams_audio["num_attention_heads"] = self.hparams_audio["n_heads"]
+ super().set_gguf_parameters()
+ self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2A)
+ self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])
+ self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)
+
+ def tensor_force_quant(self, name, new_name, bid, n_dims):
+ if ".conv" in name and ".weight" in name:
+ return gguf.GGMLQuantizationType.F32
+ return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # skip language model tensors
+ if name.startswith("lfm."):
+ return []
+
+ # for training only
+ if any(p in name for p in ["audio_loss_weight"]):
+ return []
+
+ # for audio output
+ if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]):
+ return []
+
+ # fold running_mean, running_var and eps into weight and bias for batch_norm
+ if "batch_norm" in name:
+ if self._batch_norm_tensors is None:
+ self._batch_norm_tensors = [{} for _ in range(self.block_count)]
+ assert bid is not None
+ self._batch_norm_tensors[bid][name] = data_torch
+
+ if len(self._batch_norm_tensors[bid]) < 5:
+ return []
+
+ weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"]
+ bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"]
+ running_mean = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_mean"]
+ running_var = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_var"]
+ eps = 1e-5 # default value
+
+ a = weight / torch.sqrt(running_var + eps)
+ b = bias - running_mean * a
+ return [
+ (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a),
+ (self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b),
+ ]
+
+ # reshape conv weights
+ if name.startswith("conformer.pre_encode.conv.") and name.endswith(".bias"):
+ data_torch = data_torch[:, None, None]
+ if "conv.depthwise_conv" in name and name.endswith(".weight"):
+ assert data_torch.shape[1] == 1
+ data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[2])
+ if "conv.pointwise_conv" in name and name.endswith(".weight"):
+ assert data_torch.shape[2] == 1
+ data_torch = data_torch.reshape(data_torch.shape[0], data_torch.shape[1])
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+
@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
@@ -10323,8 +10565,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
- "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
- help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
+ "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="auto",
+ help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type",
)
parser.add_argument(
"--bigendian", action="store_true",
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index b8f694e86c..5e8456a7ea 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -143,6 +143,7 @@ models = [
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
{"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
+ {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
diff --git a/docs/android.md b/docs/android.md
index d2a835653f..964ce8a1f0 100644
--- a/docs/android.md
+++ b/docs/android.md
@@ -1,7 +1,27 @@
# Android
-## Build on Android using Termux
+## Build GUI binding using Android Studio
+
+Import the `examples/llama.android` directory into Android Studio, then perform a Gradle sync and build the project.
+![llama.android imported into Android Studio](./android/imported-into-android-studio.jpg)
+
+This Android binding supports hardware acceleration up to `SME2` for **Arm** and `AMX` for **x86-64** CPUs on Android and ChromeOS devices.
+It automatically detects the host's hardware to load compatible kernels. As a result, it runs seamlessly on both the latest premium devices and older devices that may lack modern CPU features or have limited RAM, without requiring any manual configuration.
+
+A minimal Android app frontend is included to showcase the binding’s core functionalities:
+1. **Parse GGUF metadata** via `GgufMetadataReader` from either a `ContentResolver` provided `Uri` from shared storage, or a local `File` from your app's private storage.
+2. **Obtain a `InferenceEngine`** instance through the `AiChat` facade and load your selected model via its app-private file path.
+3. **Send a raw user prompt** for automatic template formatting, prefill, and batch decoding. Then collect the generated tokens in a Kotlin `Flow`.
+
+For a production-ready experience that leverages advanced features such as system prompts and benchmarks, plus friendly UI features such as model management and an Arm feature visualizer, check out [Arm AI Chat](https://play.google.com/store/apps/details?id=com.arm.aichat) on Google Play.
+This project is made possible through a collaborative effort by Arm's **CT-ML**, **CE-ML** and **STE** groups.
+
+Screenshots of the sample app: Home screen, System prompt, and "Haiku".
+
+## Build CLI on Android using Termux
[Termux](https://termux.dev/en/) is an Android terminal emulator and Linux environment app (no root required). As of writing, Termux is available experimentally in the Google Play Store; otherwise, it may be obtained directly from the project repo or on F-Droid.
@@ -32,7 +52,7 @@ To see what it might look like visually, here's an old demo of an interactive se
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
-## Cross-compile using Android NDK
+## Cross-compile CLI using Android NDK
It's possible to build `llama.cpp` for Android on your host system via CMake and the Android NDK. If you are interested in this path, ensure you already have an environment prepared to cross-compile programs for Android (i.e., install the Android SDK). Note that, unlike desktop environments, the Android environment ships with a limited set of native libraries, and so only those libraries are available to CMake when building with the Android NDK (see: https://developer.android.com/ndk/guides/stable_apis.)
Once you're ready and have cloned `llama.cpp`, invoke the following in the project directory:
diff --git a/docs/android/imported-into-android-studio.jpg b/docs/android/imported-into-android-studio.jpg
new file mode 100644
index 0000000000..bbe6867c6c
Binary files /dev/null and b/docs/android/imported-into-android-studio.jpg differ
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 02a72a9d51..f44458ed3b 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -103,6 +103,8 @@ SYCL backend supports Intel GPU Family:
- Intel Built-in Arc GPU
- Intel iGPU in Core CPU (11th Generation Core CPU and newer, refer to [oneAPI supported GPU](https://www.intel.com/content/www/us/en/developer/articles/system-requirements/intel-oneapi-base-toolkit-system-requirements.html#inpage-nav-1-1)).
+On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md), although performance is not optimal and some GPUs may not support OpenCL or have any GPGPU capabilities at all.
+
#### Verified devices
| Intel GPU | Status | Verified Model |
diff --git a/docs/backend/hexagon/CMakeUserPresets.json b/docs/backend/hexagon/CMakeUserPresets.json
index e0b19db0f5..98d7221b3a 100644
--- a/docs/backend/hexagon/CMakeUserPresets.json
+++ b/docs/backend/hexagon/CMakeUserPresets.json
@@ -22,6 +22,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
+ "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF"
}
},
@@ -36,6 +37,7 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
+ "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_CURL": "OFF"
}
},
diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md
index 9d1452e3f0..b6870f6e25 100644
--- a/docs/development/HOWTO-add-model.md
+++ b/docs/development/HOWTO-add-model.md
@@ -97,7 +97,7 @@ The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
- Add the architecture name to the `LLM_ARCH_NAMES` map.
- - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+ - Add the list of model tensors to `llm_get_tensor_names` (you may also need to update `LLM_TENSOR_NAMES`)
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
diff --git a/docs/development/parsing.md b/docs/development/parsing.md
index 113ab2e2ee..dbb989bf08 100644
--- a/docs/development/parsing.md
+++ b/docs/development/parsing.md
@@ -55,7 +55,7 @@ auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder &
```
For a more complete example, see `test_example_native()` in
-[tests/test-chat-peg-parser.cpp](tests/test-chat-peg-parser.cpp).
+[tests/test-chat-peg-parser.cpp](/tests/test-chat-peg-parser.cpp).
## Parsers/Combinators
@@ -175,7 +175,7 @@ Most model output can be placed in one of the following categories:
(Qwen3-Coder, MiniMax M2) or pseudo-function calls (LFM2)
To provide broad coverage,
-[`common/chat-peg-parser.h`](common/chat-peg-parser.h) contains builders and
+[`common/chat-peg-parser.h`](/common/chat-peg-parser.h) contains builders and
mappers that help create parsers and visitors/extractors for these types. They
require parsers to tag nodes to conform to an AST "shape". This normalization
makes it easy to extract information and generalize parsing.
diff --git a/docs/docker.md b/docs/docker.md
index b9e5015396..a3b263497c 100644
--- a/docs/docker.md
+++ b/docs/docker.md
@@ -7,9 +7,9 @@
## Images
We have three Docker images available for this project:
-1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
-2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
-3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the server executable file. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+1. `ghcr.io/ggml-org/llama.cpp:full`: This image includes the `llama-cli` and `llama-completion` executables, along with the tools to convert LLaMA models into ggml and quantize them to 4-bit. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+2. `ghcr.io/ggml-org/llama.cpp:light`: This image only includes the `llama-cli` and `llama-completion` executables. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
+3. `ghcr.io/ggml-org/llama.cpp:server`: This image only includes the `llama-server` executable. (platforms: `linux/amd64`, `linux/arm64`, `linux/s390x`)
Additionally, there are the following images, similar to the above:
@@ -44,13 +44,15 @@ docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --all-in-o
On completion, you are ready to play!
```bash
-docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
+docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf
+docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:full --run-legacy -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
```
or with a light image:
```bash
-docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
+docker run -v /path/to/models:/models --entrypoint /app/llama-cli ghcr.io/ggml-org/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf
+docker run -v /path/to/models:/models --entrypoint /app/llama-completion ghcr.io/ggml-org/llama.cpp:light -m /models/32B/ggml-model-q8_0.gguf -no-cnv -p "Building a mobile app can be done in 15 steps:" -n 512
```
or with a server image:
@@ -59,6 +61,8 @@ or with a server image:
docker run -v /path/to/models:/models -p 8080:8080 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512
```
+In the above examples, `--entrypoint /app/llama-cli` is specified for clarity, but you can safely omit it since it's the default entrypoint in the container.
+
## Docker With CUDA
Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container.
@@ -80,9 +84,9 @@ The defaults are:
The resulting images are essentially the same as the non-CUDA images:
-1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
-2. `local/llama.cpp:light-cuda`: This image only includes the main executable file.
-3. `local/llama.cpp:server-cuda`: This image only includes the server executable file.
+1. `local/llama.cpp:full-cuda`: This image includes the `llama-cli` and `llama-completion` executables, along with the tools to convert LLaMA models into ggml and quantize them to 4-bit.
+2. `local/llama.cpp:light-cuda`: This image only includes the `llama-cli` and `llama-completion` executables.
+3. `local/llama.cpp:server-cuda`: This image only includes the `llama-server` executable.
## Usage
@@ -114,9 +118,9 @@ The defaults are:
The resulting images are essentially the same as the non-MUSA images:
-1. `local/llama.cpp:full-musa`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
-2. `local/llama.cpp:light-musa`: This image only includes the main executable file.
-3. `local/llama.cpp:server-musa`: This image only includes the server executable file.
+1. `local/llama.cpp:full-musa`: This image includes the `llama-cli` and `llama-completion` executables, along with the tools to convert LLaMA models into ggml and quantize them to 4-bit.
+2. `local/llama.cpp:light-musa`: This image only includes the `llama-cli` and `llama-completion` executables.
+3. `local/llama.cpp:server-musa`: This image only includes the `llama-server` executable.
## Usage
diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp
index e9f7bf9313..dc76c4cf53 100644
--- a/examples/gen-docs/gen-docs.cpp
+++ b/examples/gen-docs/gen-docs.cpp
@@ -48,7 +48,7 @@ static void write_table(std::ofstream & file, std::vector<common_arg *> & opts)
}
}
-static void export_md(std::string fname, llama_example ex) {
+static void export_md(std::string fname, llama_example ex, std::string name) {
std::ofstream file(fname, std::ofstream::out | std::ofstream::trunc);
common_params params;
@@ -72,13 +72,14 @@ static void export_md(std::string fname, llama_example ex) {
write_table(file, common_options);
file << "\n\n**Sampling params**\n\n";
write_table(file, sparam_options);
- file << "\n\n**Example-specific params**\n\n";
+ file << "\n\n**" << name << "-specific params**\n\n";
write_table(file, specific_options);
}
int main(int, char **) {
- export_md("autogen-main.md", LLAMA_EXAMPLE_COMPLETION);
- export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER);
+ // TODO: add CLI
+ export_md("autogen-completion.md", LLAMA_EXAMPLE_COMPLETION, "Tool");
+ export_md("autogen-server.md", LLAMA_EXAMPLE_SERVER, "Server");
return 0;
}
diff --git a/examples/llama.android/README.md b/examples/llama.android/README.md
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/examples/llama.android/app/build.gradle.kts b/examples/llama.android/app/build.gradle.kts
index 8d1b37195e..3524fe39c4 100644
--- a/examples/llama.android/app/build.gradle.kts
+++ b/examples/llama.android/app/build.gradle.kts
@@ -1,16 +1,18 @@
plugins {
- id("com.android.application")
- id("org.jetbrains.kotlin.android")
+ alias(libs.plugins.android.application)
+ alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.example.llama"
- compileSdk = 34
+ compileSdk = 36
defaultConfig {
- applicationId = "com.example.llama"
+ applicationId = "com.example.llama.aichat"
+
minSdk = 33
- targetSdk = 34
+ targetSdk = 36
+
versionCode = 1
versionName = "1.0"
@@ -21,8 +23,17 @@ android {
}
buildTypes {
+ debug {
+ isMinifyEnabled = true
+ isShrinkResources = true
+ proguardFiles(
+ getDefaultProguardFile("proguard-android.txt"),
+ "proguard-rules.pro"
+ )
+ }
release {
- isMinifyEnabled = false
+ isMinifyEnabled = true
+ isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
@@ -36,30 +47,15 @@ android {
kotlinOptions {
jvmTarget = "1.8"
}
- buildFeatures {
- compose = true
- }
- composeOptions {
- kotlinCompilerExtensionVersion = "1.5.1"
- }
}
dependencies {
+ implementation(libs.bundles.androidx)
+ implementation(libs.material)
- implementation("androidx.core:core-ktx:1.12.0")
- implementation("androidx.lifecycle:lifecycle-runtime-ktx:2.6.2")
- implementation("androidx.activity:activity-compose:1.8.2")
- implementation(platform("androidx.compose:compose-bom:2023.08.00"))
- implementation("androidx.compose.ui:ui")
- implementation("androidx.compose.ui:ui-graphics")
- implementation("androidx.compose.ui:ui-tooling-preview")
- implementation("androidx.compose.material3:material3")
- implementation(project(":llama"))
- testImplementation("junit:junit:4.13.2")
- androidTestImplementation("androidx.test.ext:junit:1.1.5")
- androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
- androidTestImplementation(platform("androidx.compose:compose-bom:2023.08.00"))
- androidTestImplementation("androidx.compose.ui:ui-test-junit4")
- debugImplementation("androidx.compose.ui:ui-tooling")
- debugImplementation("androidx.compose.ui:ui-test-manifest")
+ implementation(project(":lib"))
+
+ testImplementation(libs.junit)
+ androidTestImplementation(libs.androidx.junit)
+ androidTestImplementation(libs.androidx.espresso.core)
}
diff --git a/examples/llama.android/app/proguard-rules.pro b/examples/llama.android/app/proguard-rules.pro
index f1b424510d..358020d2d2 100644
--- a/examples/llama.android/app/proguard-rules.pro
+++ b/examples/llama.android/app/proguard-rules.pro
@@ -19,3 +19,11 @@
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
+
+-keep class com.arm.aichat.* { *; }
+-keep class com.arm.aichat.gguf.* { *; }
+
+-assumenosideeffects class android.util.Log {
+ public static int v(...);
+ public static int d(...);
+}
diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml
index 41a358a299..8f7c606b41 100644
--- a/examples/llama.android/app/src/main/AndroidManifest.xml
+++ b/examples/llama.android/app/src/main/AndroidManifest.xml
@@ -1,24 +1,21 @@
-
-
-
+
+ android:exported="true">
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt b/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt
deleted file mode 100644
index 78c231ae55..0000000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/Downloadable.kt
+++ /dev/null
@@ -1,119 +0,0 @@
-package com.example.llama
-
-import android.app.DownloadManager
-import android.net.Uri
-import android.util.Log
-import androidx.compose.material3.Button
-import androidx.compose.material3.Text
-import androidx.compose.runtime.Composable
-import androidx.compose.runtime.getValue
-import androidx.compose.runtime.mutableDoubleStateOf
-import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.remember
-import androidx.compose.runtime.rememberCoroutineScope
-import androidx.compose.runtime.setValue
-import androidx.core.database.getLongOrNull
-import androidx.core.net.toUri
-import kotlinx.coroutines.delay
-import kotlinx.coroutines.launch
-import java.io.File
-
-data class Downloadable(val name: String, val source: Uri, val destination: File) {
- companion object {
- @JvmStatic
- private val tag: String? = this::class.qualifiedName
-
- sealed interface State
- data object Ready: State
- data class Downloading(val id: Long): State
- data class Downloaded(val downloadable: Downloadable): State
- data class Error(val message: String): State
-
- @JvmStatic
- @Composable
- fun Button(viewModel: MainViewModel, dm: DownloadManager, item: Downloadable) {
- var status: State by remember {
- mutableStateOf(
- if (item.destination.exists()) Downloaded(item)
- else Ready
- )
- }
- var progress by remember { mutableDoubleStateOf(0.0) }
-
- val coroutineScope = rememberCoroutineScope()
-
- suspend fun waitForDownload(result: Downloading, item: Downloadable): State {
- while (true) {
- val cursor = dm.query(DownloadManager.Query().setFilterById(result.id))
-
- if (cursor == null) {
- Log.e(tag, "dm.query() returned null")
- return Error("dm.query() returned null")
- }
-
- if (!cursor.moveToFirst() || cursor.count < 1) {
- cursor.close()
- Log.i(tag, "cursor.moveToFirst() returned false or cursor.count < 1, download canceled?")
- return Ready
- }
-
- val pix = cursor.getColumnIndex(DownloadManager.COLUMN_BYTES_DOWNLOADED_SO_FAR)
- val tix = cursor.getColumnIndex(DownloadManager.COLUMN_TOTAL_SIZE_BYTES)
- val sofar = cursor.getLongOrNull(pix) ?: 0
- val total = cursor.getLongOrNull(tix) ?: 1
- cursor.close()
-
- if (sofar == total) {
- return Downloaded(item)
- }
-
- progress = (sofar * 1.0) / total
-
- delay(1000L)
- }
- }
-
- fun onClick() {
- when (val s = status) {
- is Downloaded -> {
- viewModel.load(item.destination.path)
- }
-
- is Downloading -> {
- coroutineScope.launch {
- status = waitForDownload(s, item)
- }
- }
-
- else -> {
- item.destination.delete()
-
- val request = DownloadManager.Request(item.source).apply {
- setTitle("Downloading model")
- setDescription("Downloading model: ${item.name}")
- setAllowedNetworkTypes(DownloadManager.Request.NETWORK_WIFI)
- setDestinationUri(item.destination.toUri())
- }
-
- viewModel.log("Saving ${item.name} to ${item.destination.path}")
- Log.i(tag, "Saving ${item.name} to ${item.destination.path}")
-
- val id = dm.enqueue(request)
- status = Downloading(id)
- onClick()
- }
- }
- }
-
- Button(onClick = { onClick() }, enabled = status !is Downloading) {
- when (status) {
- is Downloading -> Text(text = "Downloading ${(progress * 100).toInt()}%")
- is Downloaded -> Text("Load ${item.name}")
- is Ready -> Text("Download ${item.name}")
- is Error -> Text("Download ${item.name}")
- }
- }
- }
-
- }
-}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
index 9da04f7d3c..52c5dc2154 100644
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -1,154 +1,257 @@
package com.example.llama
-import android.app.ActivityManager
-import android.app.DownloadManager
-import android.content.ClipData
-import android.content.ClipboardManager
import android.net.Uri
import android.os.Bundle
-import android.os.StrictMode
-import android.os.StrictMode.VmPolicy
-import android.text.format.Formatter
-import androidx.activity.ComponentActivity
-import androidx.activity.compose.setContent
-import androidx.activity.viewModels
-import androidx.compose.foundation.layout.Box
-import androidx.compose.foundation.layout.Column
-import androidx.compose.foundation.layout.Row
-import androidx.compose.foundation.layout.fillMaxSize
-import androidx.compose.foundation.layout.padding
-import androidx.compose.foundation.lazy.LazyColumn
-import androidx.compose.foundation.lazy.items
-import androidx.compose.foundation.lazy.rememberLazyListState
-import androidx.compose.material3.Button
-import androidx.compose.material3.LocalContentColor
-import androidx.compose.material3.MaterialTheme
-import androidx.compose.material3.OutlinedTextField
-import androidx.compose.material3.Surface
-import androidx.compose.material3.Text
-import androidx.compose.runtime.Composable
-import androidx.compose.ui.Modifier
-import androidx.compose.ui.unit.dp
-import androidx.core.content.getSystemService
-import com.example.llama.ui.theme.LlamaAndroidTheme
+import android.util.Log
+import android.widget.EditText
+import android.widget.TextView
+import android.widget.Toast
+import androidx.activity.enableEdgeToEdge
+import androidx.activity.result.contract.ActivityResultContracts
+import androidx.appcompat.app.AppCompatActivity
+import androidx.lifecycle.lifecycleScope
+import androidx.recyclerview.widget.LinearLayoutManager
+import androidx.recyclerview.widget.RecyclerView
+import com.arm.aichat.AiChat
+import com.arm.aichat.InferenceEngine
+import com.arm.aichat.gguf.GgufMetadata
+import com.arm.aichat.gguf.GgufMetadataReader
+import com.google.android.material.floatingactionbutton.FloatingActionButton
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
import java.io.File
+import java.io.FileOutputStream
+import java.io.InputStream
+import java.util.UUID
-class MainActivity(
- activityManager: ActivityManager? = null,
- downloadManager: DownloadManager? = null,
- clipboardManager: ClipboardManager? = null,
-): ComponentActivity() {
- private val tag: String? = this::class.simpleName
+class MainActivity : AppCompatActivity() {
- private val activityManager by lazy { activityManager ?: getSystemService()!! }
- private val downloadManager by lazy { downloadManager ?: getSystemService()!! }
- private val clipboardManager by lazy { clipboardManager ?: getSystemService()!! }
+ // Android views
+ private lateinit var ggufTv: TextView
+ private lateinit var messagesRv: RecyclerView
+ private lateinit var userInputEt: EditText
+ private lateinit var userActionFab: FloatingActionButton
- private val viewModel: MainViewModel by viewModels()
+ // Arm AI Chat inference engine
+ private lateinit var engine: InferenceEngine
- // Get a MemoryInfo object for the device's current memory status.
- private fun availableMemory(): ActivityManager.MemoryInfo {
- return ActivityManager.MemoryInfo().also { memoryInfo ->
- activityManager.getMemoryInfo(memoryInfo)
- }
- }
+ // Conversation states
+ private var isModelReady = false
+    private val messages = mutableListOf<Message>()
+ private val lastAssistantMsg = StringBuilder()
+ private val messageAdapter = MessageAdapter(messages)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
+ enableEdgeToEdge()
+ setContentView(R.layout.activity_main)
- StrictMode.setVmPolicy(
- VmPolicy.Builder(StrictMode.getVmPolicy())
- .detectLeakedClosableObjects()
- .build()
- )
+ // Find views
+ ggufTv = findViewById(R.id.gguf)
+ messagesRv = findViewById(R.id.messages)
+ messagesRv.layoutManager = LinearLayoutManager(this)
+ messagesRv.adapter = messageAdapter
+ userInputEt = findViewById(R.id.user_input)
+ userActionFab = findViewById(R.id.fab)
- val free = Formatter.formatFileSize(this, availableMemory().availMem)
- val total = Formatter.formatFileSize(this, availableMemory().totalMem)
-
- viewModel.log("Current memory: $free / $total")
- viewModel.log("Downloads directory: ${getExternalFilesDir(null)}")
-
- val extFilesDir = getExternalFilesDir(null)
-
- val models = listOf(
- Downloadable(
- "Phi-2 7B (Q4_0, 1.6 GiB)",
- Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true"),
- File(extFilesDir, "phi-2-q4_0.gguf"),
- ),
- Downloadable(
- "TinyLlama 1.1B (f16, 2.2 GiB)",
- Uri.parse("https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true"),
- File(extFilesDir, "tinyllama-1.1-f16.gguf"),
- ),
- Downloadable(
- "Phi 2 DPO (Q3_K_M, 1.48 GiB)",
- Uri.parse("https://huggingface.co/TheBloke/phi-2-dpo-GGUF/resolve/main/phi-2-dpo.Q3_K_M.gguf?download=true"),
- File(extFilesDir, "phi-2-dpo.Q3_K_M.gguf")
- ),
- )
-
- setContent {
- LlamaAndroidTheme {
- // A surface container using the 'background' color from the theme
- Surface(
- modifier = Modifier.fillMaxSize(),
- color = MaterialTheme.colorScheme.background
- ) {
- MainCompose(
- viewModel,
- clipboardManager,
- downloadManager,
- models,
- )
- }
+ // Arm AI Chat initialization
+ lifecycleScope.launch(Dispatchers.Default) {
+ engine = AiChat.getInferenceEngine(applicationContext)
+ }
+ // Upon CTA button tapped
+ userActionFab.setOnClickListener {
+ if (isModelReady) {
+ // If model is ready, validate input and send to engine
+ handleUserInput()
+ } else {
+                // Otherwise, prompt the user to select a GGUF model file on the device
+ getContent.launch(arrayOf("*/*"))
}
}
}
-}
-@Composable
-fun MainCompose(
- viewModel: MainViewModel,
- clipboard: ClipboardManager,
- dm: DownloadManager,
-    models: List<Downloadable>
-) {
- Column {
- val scrollState = rememberLazyListState()
+ private val getContent = registerForActivityResult(
+ ActivityResultContracts.OpenDocument()
+ ) { uri ->
+ Log.i(TAG, "Selected file uri:\n $uri")
+ uri?.let { handleSelectedModel(it) }
+ }
- Box(modifier = Modifier.weight(1f)) {
- LazyColumn(state = scrollState) {
- items(viewModel.messages) {
- Text(
- it,
- style = MaterialTheme.typography.bodyLarge.copy(color = LocalContentColor.current),
- modifier = Modifier.padding(16.dp)
- )
+ /**
+ * Handles the file Uri from [getContent] result
+ */
+ private fun handleSelectedModel(uri: Uri) {
+ // Update UI states
+ userActionFab.isEnabled = false
+ userInputEt.hint = "Parsing GGUF..."
+ ggufTv.text = "Parsing metadata from selected file \n$uri"
+
+ lifecycleScope.launch(Dispatchers.IO) {
+ // Parse GGUF metadata
+ Log.i(TAG, "Parsing GGUF metadata...")
+ contentResolver.openInputStream(uri)?.use {
+ GgufMetadataReader.create().readStructuredMetadata(it)
+ }?.let { metadata ->
+ // Update UI to show GGUF metadata to user
+ Log.i(TAG, "GGUF parsed: \n$metadata")
+ withContext(Dispatchers.Main) {
+ ggufTv.text = metadata.toString()
}
- }
- }
- OutlinedTextField(
- value = viewModel.message,
- onValueChange = { viewModel.updateMessage(it) },
- label = { Text("Message") },
- )
- Row {
- Button({ viewModel.send() }) { Text("Send") }
- Button({ viewModel.bench(8, 4, 1) }) { Text("Bench") }
- Button({ viewModel.clear() }) { Text("Clear") }
- Button({
- viewModel.messages.joinToString("\n").let {
- clipboard.setPrimaryClip(ClipData.newPlainText("", it))
- }
- }) { Text("Copy") }
- }
- Column {
- for (model in models) {
- Downloadable.Button(viewModel, dm, model)
+ // Ensure the model file is available
+ val modelName = metadata.filename() + FILE_EXTENSION_GGUF
+ contentResolver.openInputStream(uri)?.use { input ->
+ ensureModelFile(modelName, input)
+ }?.let { modelFile ->
+ loadModel(modelName, modelFile)
+
+ withContext(Dispatchers.Main) {
+ isModelReady = true
+ userInputEt.hint = "Type and send a message!"
+ userInputEt.isEnabled = true
+ userActionFab.setImageResource(R.drawable.outline_send_24)
+ userActionFab.isEnabled = true
+ }
+ }
}
}
}
+
+ /**
+ * Prepare the model file within app's private storage
+ */
+ private suspend fun ensureModelFile(modelName: String, input: InputStream) =
+ withContext(Dispatchers.IO) {
+ File(ensureModelsDirectory(), modelName).also { file ->
+ // Copy the file into local storage if not yet done
+ if (!file.exists()) {
+ Log.i(TAG, "Start copying file to $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Copying file..."
+ }
+
+ FileOutputStream(file).use { input.copyTo(it) }
+ Log.i(TAG, "Finished copying file to $modelName")
+ } else {
+ Log.i(TAG, "File already exists $modelName")
+ }
+ }
+ }
+
+ /**
+ * Load the model file from the app private storage
+ */
+ private suspend fun loadModel(modelName: String, modelFile: File) =
+ withContext(Dispatchers.IO) {
+ Log.i(TAG, "Loading model $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Loading model..."
+ }
+ engine.loadModel(modelFile.path)
+ }
+
+ /**
+ * Validate and send the user message into [InferenceEngine]
+ */
+ private fun handleUserInput() {
+        userInputEt.text.toString().also { userMsg ->
+            if (userMsg.isEmpty()) {
+ Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
+ } else {
+ userInputEt.text = null
+ userActionFab.isEnabled = false
+
+ // Update message states
+                messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
+ lastAssistantMsg.clear()
+ messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
+
+ lifecycleScope.launch(Dispatchers.Default) {
+                    engine.sendUserPrompt(userMsg)
+ .onCompletion {
+ withContext(Dispatchers.Main) {
+ userActionFab.isEnabled = true
+ }
+ }.collect { token ->
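+                        // Replace the trailing assistant message with a copy containing the newly appended token.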
+ val messageCount = messages.size
+ check(messageCount > 0 && !messages[messageCount - 1].isUser)
+
+ messages.removeAt(messageCount - 1).copy(
+ content = lastAssistantMsg.append(token).toString()
+ ).let { messages.add(it) }
+
+ withContext(Dispatchers.Main) {
+ messageAdapter.notifyItemChanged(messages.size - 1)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Run a benchmark with the model file
+ */
+ private suspend fun runBenchmark(modelName: String, modelFile: File) =
+ withContext(Dispatchers.Default) {
+            Log.i(TAG, "Starting benchmark for $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Running benchmark..."
+ }
+ engine.bench(
+ pp=BENCH_PROMPT_PROCESSING_TOKENS,
+ tg=BENCH_TOKEN_GENERATION_TOKENS,
+ pl=BENCH_SEQUENCE,
+ nr=BENCH_REPETITION
+ ).let { result ->
+ messages.add(Message(UUID.randomUUID().toString(), result, false))
+ withContext(Dispatchers.Main) {
+ messageAdapter.notifyItemChanged(messages.size - 1)
+ }
+ }
+ }
+
+ /**
+     * Create the `models` directory if it does not exist.
+ */
+ private fun ensureModelsDirectory() =
+ File(filesDir, DIRECTORY_MODELS).also {
+ if (it.exists() && !it.isDirectory) { it.delete() }
+ if (!it.exists()) { it.mkdir() }
+ }
+
+ companion object {
+ private val TAG = MainActivity::class.java.simpleName
+
+ private const val DIRECTORY_MODELS = "models"
+ private const val FILE_EXTENSION_GGUF = ".gguf"
+
+ private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
+ private const val BENCH_TOKEN_GENERATION_TOKENS = 128
+ private const val BENCH_SEQUENCE = 1
+ private const val BENCH_REPETITION = 3
+ }
+}
+
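+/**
+ * Derives a file name for the model from its GGUF metadata, preferring "name-sizeLabel",
+ * then "architecture-uuid" (or "architecture-timestamp"), and finally a generic
+ * timestamp-based fallback.
+ */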
+fun GgufMetadata.filename() = when {
+ basic.name != null -> {
+ basic.name?.let { name ->
+ basic.sizeLabel?.let { size ->
+ "$name-$size"
+ } ?: name
+ }
+ }
+ architecture?.architecture != null -> {
+ architecture?.architecture?.let { arch ->
+ basic.uuid?.let { uuid ->
+ "$arch-$uuid"
+ } ?: "$arch-${System.currentTimeMillis()}"
+ }
+ }
+ else -> {
+ "model-${System.currentTimeMillis().toHexString()}"
+ }
}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
deleted file mode 100644
index 45ac29938f..0000000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt
+++ /dev/null
@@ -1,105 +0,0 @@
-package com.example.llama
-
-import android.llama.cpp.LLamaAndroid
-import android.util.Log
-import androidx.compose.runtime.getValue
-import androidx.compose.runtime.mutableStateOf
-import androidx.compose.runtime.setValue
-import androidx.lifecycle.ViewModel
-import androidx.lifecycle.viewModelScope
-import kotlinx.coroutines.flow.catch
-import kotlinx.coroutines.launch
-
-class MainViewModel(private val llamaAndroid: LLamaAndroid = LLamaAndroid.instance()): ViewModel() {
- companion object {
- @JvmStatic
- private val NanosPerSecond = 1_000_000_000.0
- }
-
- private val tag: String? = this::class.simpleName
-
- var messages by mutableStateOf(listOf("Initializing..."))
- private set
-
- var message by mutableStateOf("")
- private set
-
- override fun onCleared() {
- super.onCleared()
-
- viewModelScope.launch {
- try {
- llamaAndroid.unload()
- } catch (exc: IllegalStateException) {
- messages += exc.message!!
- }
- }
- }
-
- fun send() {
- val text = message
- message = ""
-
- // Add to messages console.
- messages += text
- messages += ""
-
- viewModelScope.launch {
- llamaAndroid.send(text)
- .catch {
- Log.e(tag, "send() failed", it)
- messages += it.message!!
- }
- .collect { messages = messages.dropLast(1) + (messages.last() + it) }
- }
- }
-
- fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
- viewModelScope.launch {
- try {
- val start = System.nanoTime()
- val warmupResult = llamaAndroid.bench(pp, tg, pl, nr)
- val end = System.nanoTime()
-
- messages += warmupResult
-
- val warmup = (end - start).toDouble() / NanosPerSecond
- messages += "Warm up time: $warmup seconds, please wait..."
-
- if (warmup > 5.0) {
- messages += "Warm up took too long, aborting benchmark"
- return@launch
- }
-
- messages += llamaAndroid.bench(512, 128, 1, 3)
- } catch (exc: IllegalStateException) {
- Log.e(tag, "bench() failed", exc)
- messages += exc.message!!
- }
- }
- }
-
- fun load(pathToModel: String) {
- viewModelScope.launch {
- try {
- llamaAndroid.load(pathToModel)
- messages += "Loaded $pathToModel"
- } catch (exc: IllegalStateException) {
- Log.e(tag, "load() failed", exc)
- messages += exc.message!!
- }
- }
- }
-
- fun updateMessage(newMessage: String) {
- message = newMessage
- }
-
- fun clear() {
- messages = listOf()
- }
-
- fun log(message: String) {
- messages += message
- }
-}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
new file mode 100644
index 0000000000..0439f96441
--- /dev/null
+++ b/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
@@ -0,0 +1,51 @@
+package com.example.llama
+
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import android.widget.TextView
+import androidx.recyclerview.widget.RecyclerView
+
+data class Message(
+ val id: String,
+ val content: String,
+ val isUser: Boolean
+)
+
+class MessageAdapter(
+    private val messages: List<Message>
+) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
+
+ companion object {
+ private const val VIEW_TYPE_USER = 1
+ private const val VIEW_TYPE_ASSISTANT = 2
+ }
+
+ override fun getItemViewType(position: Int): Int {
+ return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
+ }
+
+ override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
+ val layoutInflater = LayoutInflater.from(parent.context)
+ return if (viewType == VIEW_TYPE_USER) {
+ val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
+ UserMessageViewHolder(view)
+ } else {
+ val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
+ AssistantMessageViewHolder(view)
+ }
+ }
+
+ override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
+ val message = messages[position]
+ if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
+            val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
+ textView.text = message.content
+ }
+ }
+
+ override fun getItemCount(): Int = messages.size
+
+ class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
+ class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
+}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt
deleted file mode 100644
index 40c30e8d97..0000000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Color.kt
+++ /dev/null
@@ -1,11 +0,0 @@
-package com.example.llama.ui.theme
-
-import androidx.compose.ui.graphics.Color
-
-val Purple80 = Color(0xFFD0BCFF)
-val PurpleGrey80 = Color(0xFFCCC2DC)
-val Pink80 = Color(0xFFEFB8C8)
-
-val Purple40 = Color(0xFF6650a4)
-val PurpleGrey40 = Color(0xFF625b71)
-val Pink40 = Color(0xFF7D5260)
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt
deleted file mode 100644
index e742220a8d..0000000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Theme.kt
+++ /dev/null
@@ -1,70 +0,0 @@
-package com.example.llama.ui.theme
-
-import android.app.Activity
-import android.os.Build
-import androidx.compose.foundation.isSystemInDarkTheme
-import androidx.compose.material3.MaterialTheme
-import androidx.compose.material3.darkColorScheme
-import androidx.compose.material3.dynamicDarkColorScheme
-import androidx.compose.material3.dynamicLightColorScheme
-import androidx.compose.material3.lightColorScheme
-import androidx.compose.runtime.Composable
-import androidx.compose.runtime.SideEffect
-import androidx.compose.ui.graphics.toArgb
-import androidx.compose.ui.platform.LocalContext
-import androidx.compose.ui.platform.LocalView
-import androidx.core.view.WindowCompat
-
-private val DarkColorScheme = darkColorScheme(
- primary = Purple80,
- secondary = PurpleGrey80,
- tertiary = Pink80
-)
-
-private val LightColorScheme = lightColorScheme(
- primary = Purple40,
- secondary = PurpleGrey40,
- tertiary = Pink40
-
- /* Other default colors to override
- background = Color(0xFFFFFBFE),
- surface = Color(0xFFFFFBFE),
- onPrimary = Color.White,
- onSecondary = Color.White,
- onTertiary = Color.White,
- onBackground = Color(0xFF1C1B1F),
- onSurface = Color(0xFF1C1B1F),
- */
-)
-
-@Composable
-fun LlamaAndroidTheme(
- darkTheme: Boolean = isSystemInDarkTheme(),
- // Dynamic color is available on Android 12+
- dynamicColor: Boolean = true,
- content: @Composable () -> Unit
-) {
- val colorScheme = when {
- dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {
- val context = LocalContext.current
- if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)
- }
-
- darkTheme -> DarkColorScheme
- else -> LightColorScheme
- }
- val view = LocalView.current
- if (!view.isInEditMode) {
- SideEffect {
- val window = (view.context as Activity).window
- window.statusBarColor = colorScheme.primary.toArgb()
- WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme
- }
- }
-
- MaterialTheme(
- colorScheme = colorScheme,
- typography = Typography,
- content = content
- )
-}
diff --git a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt b/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt
deleted file mode 100644
index 0b87946ca3..0000000000
--- a/examples/llama.android/app/src/main/java/com/example/llama/ui/theme/Type.kt
+++ /dev/null
@@ -1,34 +0,0 @@
-package com.example.llama.ui.theme
-
-import androidx.compose.material3.Typography
-import androidx.compose.ui.text.TextStyle
-import androidx.compose.ui.text.font.FontFamily
-import androidx.compose.ui.text.font.FontWeight
-import androidx.compose.ui.unit.sp
-
-// Set of Material typography styles to start with
-val Typography = Typography(
- bodyLarge = TextStyle(
- fontFamily = FontFamily.Default,
- fontWeight = FontWeight.Normal,
- fontSize = 16.sp,
- lineHeight = 24.sp,
- letterSpacing = 0.5.sp
- )
- /* Other default text styles to override
- titleLarge = TextStyle(
- fontFamily = FontFamily.Default,
- fontWeight = FontWeight.Normal,
- fontSize = 22.sp,
- lineHeight = 28.sp,
- letterSpacing = 0.sp
- ),
- labelSmall = TextStyle(
- fontFamily = FontFamily.Default,
- fontWeight = FontWeight.Medium,
- fontSize = 11.sp,
- lineHeight = 16.sp,
- letterSpacing = 0.5.sp
- )
- */
-)
diff --git a/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml b/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml
new file mode 100644
index 0000000000..f90c3db458
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml
@@ -0,0 +1,4 @@
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml b/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml
new file mode 100644
index 0000000000..3ca7daefec
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml
@@ -0,0 +1,4 @@
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml b/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml
new file mode 100644
index 0000000000..f58b501e3b
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml
@@ -0,0 +1,10 @@
+
+
+
diff --git a/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml b/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml
new file mode 100644
index 0000000000..712adc00c4
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml
@@ -0,0 +1,11 @@
+
+
+
diff --git a/examples/llama.android/app/src/main/res/layout/activity_main.xml b/examples/llama.android/app/src/main/res/layout/activity_main.xml
new file mode 100644
index 0000000000..ad805a674e
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/layout/activity_main.xml
@@ -0,0 +1,78 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml b/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml
new file mode 100644
index 0000000000..2c8e4bc2e1
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/layout/item_message_user.xml b/examples/llama.android/app/src/main/res/layout/item_message_user.xml
new file mode 100644
index 0000000000..5aa79f2df3
--- /dev/null
+++ b/examples/llama.android/app/src/main/res/layout/item_message_user.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
diff --git a/examples/llama.android/app/src/main/res/values/strings.xml b/examples/llama.android/app/src/main/res/values/strings.xml
index 7a9d314e29..36059fc799 100644
--- a/examples/llama.android/app/src/main/res/values/strings.xml
+++ b/examples/llama.android/app/src/main/res/values/strings.xml
@@ -1,3 +1,3 @@
-    <string name="app_name">LlamaAndroid</string>
+    <string name="app_name">AI Chat basic sample</string>
diff --git a/examples/llama.android/app/src/main/res/values/themes.xml b/examples/llama.android/app/src/main/res/values/themes.xml
index 8a24fda566..2e4fdad72e 100644
--- a/examples/llama.android/app/src/main/res/values/themes.xml
+++ b/examples/llama.android/app/src/main/res/values/themes.xml
@@ -1,5 +1,10 @@
-
+
+
+
diff --git a/examples/llama.android/build.gradle.kts b/examples/llama.android/build.gradle.kts
index acd1ada7d9..076a0f1c9a 100644
--- a/examples/llama.android/build.gradle.kts
+++ b/examples/llama.android/build.gradle.kts
@@ -1,6 +1,6 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
- id("com.android.application") version "8.2.0" apply false
- id("org.jetbrains.kotlin.android") version "1.9.0" apply false
- id("com.android.library") version "8.2.0" apply false
+ alias(libs.plugins.android.application) apply false
+ alias(libs.plugins.android.library) apply false
+ alias(libs.plugins.jetbrains.kotlin.android) apply false
}
diff --git a/examples/llama.android/gradle.properties b/examples/llama.android/gradle.properties
index 2cbd6d19d3..8888cc9c51 100644
--- a/examples/llama.android/gradle.properties
+++ b/examples/llama.android/gradle.properties
@@ -21,3 +21,4 @@ kotlin.code.style=official
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
+android.native.buildOutput=verbose
diff --git a/examples/llama.android/gradle/libs.versions.toml b/examples/llama.android/gradle/libs.versions.toml
new file mode 100644
index 0000000000..df32a75661
--- /dev/null
+++ b/examples/llama.android/gradle/libs.versions.toml
@@ -0,0 +1,53 @@
+[versions]
+
+# Plugins
+agp = "8.13.0"
+kotlin = "2.2.20"
+
+# AndroidX
+activity = "1.11.0"
+appcompat = "1.7.1"
+core-ktx = "1.17.0"
+constraint-layout = "2.2.1"
+datastore-preferences = "1.1.7"
+
+# Material
+material = "1.13.0"
+
+# Testing
+espresso-core = "3.7.0"
+androidx-junit = "1.3.0"
+junit = "4.13.2"
+
+
+[plugins]
+android-application = { id = "com.android.application", version.ref = "agp" }
+android-library = { id = "com.android.library", version.ref = "agp" }
+jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
+
+
+[libraries]
+
+# AndroidX
+androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
+androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
+androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
+androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
+androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
+
+#Material
+material = { group = "com.google.android.material", name = "material", version.ref = "material" }
+
+# Testing
+androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
+androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
+junit = { group = "junit", name = "junit", version.ref = "junit" }
+
+[bundles]
+androidx = [
+ "androidx-activity",
+ "androidx-appcompat",
+ "androidx-constraintlayout",
+ "androidx-core-ktx",
+ "androidx-datastore-preferences",
+]
diff --git a/examples/llama.android/gradle/wrapper/gradle-wrapper.properties b/examples/llama.android/gradle/wrapper/gradle-wrapper.properties
index a3958c140b..6b993e909f 100644
--- a/examples/llama.android/gradle/wrapper/gradle-wrapper.properties
+++ b/examples/llama.android/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,6 @@
-#Thu Dec 21 14:31:09 AEDT 2023
+#Tue Apr 01 11:15:06 PDT 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.2-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
diff --git a/examples/llama.android/llama/.gitignore b/examples/llama.android/lib/.gitignore
similarity index 100%
rename from examples/llama.android/llama/.gitignore
rename to examples/llama.android/lib/.gitignore
diff --git a/examples/llama.android/lib/build.gradle.kts b/examples/llama.android/lib/build.gradle.kts
new file mode 100644
index 0000000000..5255f0c17b
--- /dev/null
+++ b/examples/llama.android/lib/build.gradle.kts
@@ -0,0 +1,78 @@
+plugins {
+ alias(libs.plugins.android.library)
+ alias(libs.plugins.jetbrains.kotlin.android)
+}
+
+android {
+ namespace = "com.arm.aichat"
+ compileSdk = 36
+
+ ndkVersion = "29.0.13113456"
+
+ defaultConfig {
+ minSdk = 33
+
+ testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
+ consumerProguardFiles("consumer-rules.pro")
+
+ ndk {
+ abiFilters += listOf("arm64-v8a", "x86_64")
+ }
+ externalNativeBuild {
+ cmake {
+ arguments += "-DCMAKE_BUILD_TYPE=Release"
+ arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
+ arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
+
+ arguments += "-DBUILD_SHARED_LIBS=ON"
+ arguments += "-DLLAMA_BUILD_COMMON=ON"
+ arguments += "-DLLAMA_CURL=OFF"
+
+ arguments += "-DGGML_NATIVE=OFF"
+ arguments += "-DGGML_BACKEND_DL=ON"
+ arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
+ arguments += "-DGGML_LLAMAFILE=OFF"
+ }
+ }
+ aarMetadata {
+ minCompileSdk = 35
+ }
+ }
+ externalNativeBuild {
+ cmake {
+ path("src/main/cpp/CMakeLists.txt")
+ version = "3.31.6"
+ }
+ }
+ compileOptions {
+ sourceCompatibility = JavaVersion.VERSION_17
+ targetCompatibility = JavaVersion.VERSION_17
+ }
+ kotlin {
+ jvmToolchain(17)
+
+ compileOptions {
+ targetCompatibility = JavaVersion.VERSION_17
+ }
+ }
+
+ packaging {
+ resources {
+ excludes += "/META-INF/{AL2.0,LGPL2.1}"
+ }
+ }
+
+ publishing {
+ singleVariant("release") {
+ withJavadocJar()
+ }
+ }
+}
+
+dependencies {
+ implementation(libs.androidx.core.ktx)
+ implementation(libs.androidx.datastore.preferences)
+
+ testImplementation(libs.junit)
+ androidTestImplementation(libs.androidx.junit)
+}
diff --git a/examples/llama.android/lib/consumer-rules.pro b/examples/llama.android/lib/consumer-rules.pro
new file mode 100644
index 0000000000..e6eb6f5474
--- /dev/null
+++ b/examples/llama.android/lib/consumer-rules.pro
@@ -0,0 +1,8 @@
+-keep class com.arm.aichat.* { *; }
+-keep class com.arm.aichat.gguf.* { *; }
+
+-keepclasseswithmembernames class * {
+    native <methods>;
+}
+
+-keep class kotlin.Metadata { *; }
diff --git a/examples/llama.android/llama/proguard-rules.pro b/examples/llama.android/lib/proguard-rules.pro
similarity index 100%
rename from examples/llama.android/llama/proguard-rules.pro
rename to examples/llama.android/lib/proguard-rules.pro
diff --git a/examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt b/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
similarity index 100%
rename from examples/llama.android/llama/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
rename to examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt
diff --git a/examples/llama.android/llama/src/main/AndroidManifest.xml b/examples/llama.android/lib/src/main/AndroidManifest.xml
similarity index 100%
rename from examples/llama.android/llama/src/main/AndroidManifest.xml
rename to examples/llama.android/lib/src/main/AndroidManifest.xml
diff --git a/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt
new file mode 100644
index 0000000000..7862c61a3f
--- /dev/null
+++ b/examples/llama.android/lib/src/main/cpp/CMakeLists.txt
@@ -0,0 +1,56 @@
+cmake_minimum_required(VERSION 3.31.6)
+
+project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
+
+set(CMAKE_C_STANDARD 11)
+set(CMAKE_C_STANDARD_REQUIRED true)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED true)
+
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
+
+# --------------------------------------------------------------------------
+# AI Chat library
+# --------------------------------------------------------------------------
+
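+# Select CPU options per Android ABI: KleidiAI kernels and OpenMP on arm64-v8a, neither on x86_64.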
+if(DEFINED ANDROID_ABI)
+ message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
+ if(ANDROID_ABI STREQUAL "arm64-v8a")
+ set(GGML_SYSTEM_ARCH "ARM")
+ set(GGML_CPU_KLEIDIAI ON)
+ set(GGML_OPENMP ON)
+ elseif(ANDROID_ABI STREQUAL "x86_64")
+ set(GGML_SYSTEM_ARCH "x86")
+ set(GGML_CPU_KLEIDIAI OFF)
+ set(GGML_OPENMP OFF)
+ else()
+ message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
+ endif()
+endif()
+
+set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
+add_subdirectory(${LLAMA_SRC} build-llama)
+
+add_library(${CMAKE_PROJECT_NAME} SHARED
+ ai_chat.cpp)
+
+target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
+ GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
+ GGML_CPU_KLEIDIAI=$
+ GGML_OPENMP=$
+)
+
+target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
+ ${LLAMA_SRC}
+ ${LLAMA_SRC}/common
+ ${LLAMA_SRC}/include
+ ${LLAMA_SRC}/ggml/include
+ ${LLAMA_SRC}/ggml/src)
+
+target_link_libraries(${CMAKE_PROJECT_NAME}
+ llama
+ common
+ android
+ log)
diff --git a/examples/llama.android/lib/src/main/cpp/ai_chat.cpp b/examples/llama.android/lib/src/main/cpp/ai_chat.cpp
new file mode 100644
index 0000000000..d655a0965f
--- /dev/null
+++ b/examples/llama.android/lib/src/main/cpp/ai_chat.cpp
@@ -0,0 +1,565 @@
+#include <jni.h>
+#include <unistd.h>
+#include <algorithm>
+#include <cmath>
+#include <iomanip>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "logging.h"
+#include "chat.h"
+#include "common.h"
+#include "llama.h"
+
+template <typename T>
+static std::string join(const std::vector<T> &values, const std::string &delim) {
+ std::ostringstream str;
+ for (size_t i = 0; i < values.size(); i++) {
+ str << values[i];
+ if (i < values.size() - 1) { str << delim; }
+ }
+ return str.str();
+}
+
+/**
+ * LLama resources: context, model, batch and sampler
+ */
+constexpr int N_THREADS_MIN = 2;
+constexpr int N_THREADS_MAX = 4;
+constexpr int N_THREADS_HEADROOM = 2;
+
+constexpr int DEFAULT_CONTEXT_SIZE = 8192;
+constexpr int OVERFLOW_HEADROOM = 4;
+constexpr int BATCH_SIZE = 512;
+constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
+
+static llama_model * g_model;
+static llama_context * g_context;
+static llama_batch g_batch;
+static common_chat_templates_ptr g_chat_templates;
+static common_sampler * g_sampler;
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
+ // Set llama log handler to Android
+ llama_log_set(aichat_android_log_callback, nullptr);
+
+ // Loading all CPU backend variants
+ const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
+ LOGi("Loading backends from %s", path_to_backend);
+ ggml_backend_load_all_from_path(path_to_backend);
+ env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
+
+ // Initialize backends
+ llama_backend_init();
+ LOGi("Backend initiated; Log handler set.");
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
+ llama_model_params model_params = llama_model_default_params();
+
+ const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
+ LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
+
+ auto *model = llama_model_load_from_file(model_path, model_params);
+ env->ReleaseStringUTFChars(jmodel_path, model_path);
+ if (!model) {
+ return 1;
+ }
+ g_model = model;
+ return 0;
+}
+
+static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
+ if (!model) {
+ LOGe("%s: model cannot be null", __func__);
+ return nullptr;
+ }
+
+ // Multi-threading setup
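+    // Use the number of online cores minus N_THREADS_HEADROOM, clamped to [N_THREADS_MIN, N_THREADS_MAX].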
+ const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
+ (int) sysconf(_SC_NPROCESSORS_ONLN) -
+ N_THREADS_HEADROOM));
+ LOGi("%s: Using %d threads", __func__, n_threads);
+
+ // Context parameters setup
+ llama_context_params ctx_params = llama_context_default_params();
+ const int trained_context_size = llama_model_n_ctx_train(model);
+ if (n_ctx > trained_context_size) {
+ LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
+ __func__, trained_context_size, n_ctx);
+ }
+ ctx_params.n_ctx = n_ctx;
+ ctx_params.n_batch = BATCH_SIZE;
+ ctx_params.n_ubatch = BATCH_SIZE;
+ ctx_params.n_threads = n_threads;
+ ctx_params.n_threads_batch = n_threads;
+ auto *context = llama_init_from_model(g_model, ctx_params);
+ if (context == nullptr) {
+        LOGe("%s: llama_init_from_model() returned null", __func__);
+ }
+ return context;
+}
+
+static common_sampler *new_sampler(float temp) {
+ common_params_sampling sparams;
+ sparams.temp = temp;
+ return common_sampler_init(g_model, sparams);
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
+ auto *context = init_context(g_model);
+ if (!context) { return 1; }
+ g_context = context;
+ g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
+ g_chat_templates = common_chat_templates_init(g_model, "");
+ g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
+ return 0;
+}
+
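+// Returns a comma-separated list of registered non-CPU backends, or "CPU" if none are present.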
+static std::string get_backend() {
+    std::vector<std::string> backends;
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+ auto *reg = ggml_backend_reg_get(i);
+ std::string name = ggml_backend_reg_name(reg);
+ if (name != "CPU") {
+ backends.push_back(ggml_backend_reg_name(reg));
+ }
+ }
+ return backends.empty() ? "CPU" : join(backends, ",");
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
+ return env->NewStringUTF(llama_print_system_info());
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
+ jint pl, jint nr) {
+ auto *context = init_context(g_model, pp);
+ if (!context) {
+        const auto *const err_msg = "Failed to init_context! Bench aborted.";
+ LOGe(err_msg);
+ return env->NewStringUTF(err_msg);
+ }
+
+ auto pp_avg = 0.0;
+ auto tg_avg = 0.0;
+ auto pp_std = 0.0;
+ auto tg_std = 0.0;
+
+ const uint32_t n_ctx = llama_n_ctx(context);
+ LOGi("n_ctx = %d", n_ctx);
+
+ int i, j;
+ int nri;
+ for (nri = 0; nri < nr; nri++) {
+ LOGi("Benchmark prompt processing (pp = %d)", pp);
+
+ common_batch_clear(g_batch);
+
+ const int n_tokens = pp;
+ for (i = 0; i < n_tokens; i++) {
+ common_batch_add(g_batch, 0, i, {0}, false);
+ }
+
+ g_batch.logits[g_batch.n_tokens - 1] = true;
+ llama_memory_clear(llama_get_memory(context), false);
+
+ const auto t_pp_start = ggml_time_us();
+ if (llama_decode(context, g_batch) != 0) {
+ LOGe("llama_decode() failed during prompt processing");
+ }
+ const auto t_pp_end = ggml_time_us();
+
+ // bench text generation
+
+ LOGi("Benchmark text generation (tg = %d)", tg);
+
+ llama_memory_clear(llama_get_memory(context), false);
+ const auto t_tg_start = ggml_time_us();
+ for (i = 0; i < tg; i++) {
+ common_batch_clear(g_batch);
+ for (j = 0; j < pl; j++) {
+ common_batch_add(g_batch, 0, i, {j}, true);
+ }
+
+ if (llama_decode(context, g_batch) != 0) {
+ LOGe("llama_decode() failed during text generation");
+ }
+ }
+ const auto t_tg_end = ggml_time_us();
+
+ llama_memory_clear(llama_get_memory(context), false);
+
+ const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
+ const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
+
+ const auto speed_pp = double(pp) / t_pp;
+ const auto speed_tg = double(pl * tg) / t_tg;
+
+ pp_avg += speed_pp;
+ tg_avg += speed_tg;
+
+ pp_std += speed_pp * speed_pp;
+ tg_std += speed_tg * speed_tg;
+
+ LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
+ }
+
+ llama_free(context);
+
+ pp_avg /= double(nr);
+ tg_avg /= double(nr);
+
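+    // Sample standard deviation recovered from the running sums of squares:
+    //   s = sqrt((sum(x^2) - n * mean^2) / (n - 1))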
+ if (nr > 1) {
+ pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
+ tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
+ } else {
+ pp_std = 0;
+ tg_std = 0;
+ }
+
+ char model_desc[128];
+ llama_model_desc(g_model, model_desc, sizeof(model_desc));
+
+ const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
+ const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
+
+ const auto backend = get_backend();
+ std::stringstream result;
+ result << std::setprecision(3);
+ result << "| model | size | params | backend | test | t/s |\n";
+ result << "| --- | --- | --- | --- | --- | --- |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
+ << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
+ result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
+ << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
+ return env->NewStringUTF(result.str().c_str());
+}
+
+
+/**
+ * Completion loop's long-term states:
+ * - chat management
+ * - position tracking
+ */
+constexpr const char *ROLE_SYSTEM = "system";
+constexpr const char *ROLE_USER = "user";
+constexpr const char *ROLE_ASSISTANT = "assistant";
+
+static std::vector<common_chat_msg> chat_msgs;
+static llama_pos system_prompt_position;
+static llama_pos current_position;
+
+static void reset_long_term_states(const bool clear_kv_cache = true) {
+ chat_msgs.clear();
+ system_prompt_position = 0;
+ current_position = 0;
+
+ if (clear_kv_cache)
+ llama_memory_clear(llama_get_memory(g_context), false);
+}
+
+/**
+ * TODO-hyin: implement sliding-window version as a better alternative
+ *
+ * Context shifting by discarding the older half of the tokens appended after system prompt:
+ * - take the [system_prompt_position] first tokens from the original prompt
+ * - take half of the last (current_position - system_prompt_position) tokens
+ * - recompute the logits in batches
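+ *
+ * For example, with system_prompt_position = 100 and current_position = 900,
+ * n_discard = 400: positions [100, 500) are removed from the KV cache and
+ * positions [500, 900) are shifted down to [100, 500).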
+ */
+static void shift_context() {
+ const int n_discard = (current_position - system_prompt_position) / 2;
+ LOGi("%s: Discarding %d tokens", __func__, n_discard);
+ llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
+ llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
+ current_position -= n_discard;
+ LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
+}
+
+static std::string chat_add_and_format(const std::string &role, const std::string &content) {
+ common_chat_msg new_msg;
+ new_msg.role = role;
+ new_msg.content = content;
+ auto formatted = common_chat_format_single(
+ g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
+ chat_msgs.push_back(new_msg);
+ LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
+ return formatted;
+}
+
+/**
+ * Completion loop's short-term states:
+ * - stop generation position
+ * - token chars caching
+ * - current assistant message being generated
+ */
+static llama_pos stop_generation_position;
+static std::string cached_token_chars;
+static std::ostringstream assistant_ss;
+
+static void reset_short_term_states() {
+ stop_generation_position = 0;
+ cached_token_chars.clear();
+ assistant_ss.str("");
+}
+
+static int decode_tokens_in_batches(
+ llama_context *context,
+ llama_batch &batch,
+ const llama_tokens &tokens,
+ const llama_pos start_pos,
+ const bool compute_last_logit = false) {
+ // Process tokens in batches using the global batch
+ LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
+ for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
+ const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
+ common_batch_clear(batch);
+ LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
+
+ // Shift context if current batch cannot fit into the context
+ if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+ LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
+ shift_context();
+ }
+
+ // Add tokens to the batch with proper positions
+ for (int j = 0; j < cur_batch_size; j++) {
+ const llama_token token_id = tokens[i + j];
+ const llama_pos position = start_pos + i + j;
+            const bool want_logit = compute_last_logit && (i + j == (int) tokens.size() - 1);
+ common_batch_add(batch, token_id, position, {0}, want_logit);
+ }
+
+ // Decode this batch
+ const int decode_result = llama_decode(context, batch);
+ if (decode_result) {
+ LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
+ return 1;
+ }
+ }
+ return 0;
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
+ JNIEnv *env,
+ jobject /*unused*/,
+ jstring jsystem_prompt
+) {
+ // Reset long-term & short-term states
+ reset_long_term_states();
+ reset_short_term_states();
+
+ // Obtain system prompt from JEnv
+ const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
+ LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
+ std::string formatted_system_prompt(system_prompt);
+ env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
+
+ // Format system prompt if applicable
+ const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+ if (has_chat_template) {
+ formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
+ }
+
+ // Tokenize system prompt
+ const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
+ has_chat_template, has_chat_template);
+ for (auto id: system_tokens) {
+ LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
+ }
+
+ // Handle context overflow
+ const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
+ if ((int) system_tokens.size() > max_batch_size) {
+ LOGe("%s: System prompt too long for context! %d tokens, max: %d",
+ __func__, (int) system_tokens.size(), max_batch_size);
+ return 1;
+ }
+
+ // Decode system tokens in batches
+ if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
+ LOGe("%s: llama_decode() failed!", __func__);
+ return 2;
+ }
+
+ // Update position
+ system_prompt_position = current_position = (int) system_tokens.size();
+ return 0;
+}
+
+extern "C"
+JNIEXPORT jint JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
+ JNIEnv *env,
+ jobject /*unused*/,
+ jstring juser_prompt,
+ jint n_predict
+) {
+ // Reset short-term states
+ reset_short_term_states();
+
+ // Obtain and tokenize user prompt
+ const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
+ LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
+ std::string formatted_user_prompt(user_prompt);
+ env->ReleaseStringUTFChars(juser_prompt, user_prompt);
+
+ // Format user prompt if applicable
+ const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
+ if (has_chat_template) {
+ formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
+ }
+
+ // Decode formatted user prompts
+ auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
+ for (auto id: user_tokens) {
+ LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
+ }
+
+ // Ensure user prompt doesn't exceed the context size by truncating if necessary.
+ const int user_prompt_size = (int) user_tokens.size();
+ const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
+ if (user_prompt_size > max_batch_size) {
+ const int skipped_tokens = user_prompt_size - max_batch_size;
+ user_tokens.resize(max_batch_size);
+ LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
+ }
+
+ // Decode user tokens in batches
+ if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
+ LOGe("%s: llama_decode() failed!", __func__);
+ return 2;
+ }
+
+ // Update position
+    current_position += (int) user_tokens.size();
+    stop_generation_position = current_position + n_predict;
+ return 0;
+}
+
+static bool is_valid_utf8(const char *string) {
+ if (!string) { return true; }
+
+ const auto *bytes = (const unsigned char *) string;
+ int num;
+
+ while (*bytes != 0x00) {
+ if ((*bytes & 0x80) == 0x00) {
+ // U+0000 to U+007F
+ num = 1;
+ } else if ((*bytes & 0xE0) == 0xC0) {
+ // U+0080 to U+07FF
+ num = 2;
+ } else if ((*bytes & 0xF0) == 0xE0) {
+ // U+0800 to U+FFFF
+ num = 3;
+ } else if ((*bytes & 0xF8) == 0xF0) {
+ // U+10000 to U+10FFFF
+ num = 4;
+ } else {
+ return false;
+ }
+
+ bytes += 1;
+ for (int i = 1; i < num; ++i) {
+ if ((*bytes & 0xC0) != 0x80) {
+ return false;
+ }
+ bytes += 1;
+ }
+ }
+ return true;
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
+ JNIEnv *env,
+ jobject /*unused*/
+) {
+ // Infinite text generation via context shifting
+ if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+ LOGw("%s: Context full! Shifting...", __func__);
+ shift_context();
+ }
+
+ // Stop if reaching the marked position
+ if (current_position >= stop_generation_position) {
+ LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
+ return nullptr;
+ }
+
+ // Sample next token
+ const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
+ common_sampler_accept(g_sampler, new_token_id, true);
+
+ // Populate the batch with new token, then decode
+ common_batch_clear(g_batch);
+ common_batch_add(g_batch, new_token_id, current_position, {0}, true);
+ if (llama_decode(g_context, g_batch) != 0) {
+ LOGe("%s: llama_decode() failed for generated token", __func__);
+ return nullptr;
+ }
+
+ // Update position
+ current_position++;
+
+ // Stop if next token is EOG
+ if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
+ LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+ chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
+ return nullptr;
+ }
+
+ // If not EOG, convert to text
+ auto new_token_chars = common_token_to_piece(g_context, new_token_id);
+ cached_token_chars += new_token_chars;
+
+ // Create and return a valid UTF-8 Java string
+ jstring result = nullptr;
+ if (is_valid_utf8(cached_token_chars.c_str())) {
+ result = env->NewStringUTF(cached_token_chars.c_str());
+ LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
+
+ assistant_ss << cached_token_chars;
+ cached_token_chars.clear();
+ } else {
+ LOGv("id: %d,\tappend to cache", new_token_id);
+ result = env->NewStringUTF("");
+ }
+ return result;
+}
+
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
+ // Reset long-term & short-term states
+ reset_long_term_states();
+ reset_short_term_states();
+
+ // Free up resources
+ common_sampler_free(g_sampler);
+ g_chat_templates.reset();
+ llama_batch_free(g_batch);
+ llama_free(g_context);
+ llama_model_free(g_model);
+}
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv * /*unused*/, jobject /*unused*/) {
+ llama_backend_free();
+}
diff --git a/examples/llama.android/lib/src/main/cpp/logging.h b/examples/llama.android/lib/src/main/cpp/logging.h
new file mode 100644
index 0000000000..2e768d2beb
--- /dev/null
+++ b/examples/llama.android/lib/src/main/cpp/logging.h
@@ -0,0 +1,61 @@
+//
+// Created by Han Yin on 10/31/25.
+//
+
+#pragma once
+
+#include <android/log.h>
+#include "ggml.h"  // for ggml_log_level
+
+#ifndef LOG_TAG
+#define LOG_TAG "ai-chat"
+#endif
+
+#ifndef LOG_MIN_LEVEL
+#if defined(NDEBUG)
+#define LOG_MIN_LEVEL ANDROID_LOG_INFO
+#else
+#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
+#endif
+#endif
+
+static inline int ai_should_log(int prio) {
+ return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
+}
+
+#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
+#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
+#else
+#define LOGv(...) ((void)0)
+#endif
+
+#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
+#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
+#else
+#define LOGd(...) ((void)0)
+#endif
+
+#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
+#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
+#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
+
+static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
+ switch (level) {
+ case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
+ case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
+ case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
+ case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
+ default: return ANDROID_LOG_DEFAULT;
+ }
+}
+
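+// Matches the ggml log callback signature; native init code can register it,
+// e.g. via llama_log_set(aichat_android_log_callback, nullptr).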
+static inline void aichat_android_log_callback(enum ggml_log_level level,
+ const char* text,
+ void* /*user*/) {
+ const int prio = android_log_prio_from_ggml(level);
+ if (!ai_should_log(prio)) return;
+ __android_log_write(prio, LOG_TAG, text);
+}
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
new file mode 100644
index 0000000000..b72a24ec1d
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
@@ -0,0 +1,14 @@
+package com.arm.aichat
+
+import android.content.Context
+import com.arm.aichat.internal.InferenceEngineImpl
+
+/**
+ * Main entry point for Arm's AI Chat library.
+ */
+object AiChat {
+ /**
+ * Get the inference engine single instance.
+ */
+ fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
+}
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
new file mode 100644
index 0000000000..44852fa828
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
@@ -0,0 +1,89 @@
+package com.arm.aichat
+
+import com.arm.aichat.InferenceEngine.State
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.StateFlow
+
+/**
+ * Interface defining the core LLM inference operations.
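+ *
+ * A minimal end-to-end sketch from a coroutine (the model path and `context` are illustrative):
+ * ```
+ * val engine = AiChat.getInferenceEngine(context)
+ * engine.loadModel("/data/local/tmp/model.gguf")
+ * engine.setSystemPrompt("You are a helpful assistant.")
+ * engine.sendUserPrompt("Hello!").collect { token -> print(token) }
+ * engine.cleanUp()
+ * engine.destroy()
+ * ```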
+ */
+interface InferenceEngine {
+ /**
+ * Current state of the inference engine
+ */
+    val state: StateFlow<State>
+
+ /**
+ * Load a model from the given path.
+ *
+ * @throws UnsupportedArchitectureException if model architecture not supported
+ */
+ suspend fun loadModel(pathToModel: String)
+
+ /**
+ * Sends a system prompt to the loaded model
+ */
+ suspend fun setSystemPrompt(systemPrompt: String)
+
+ /**
+ * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
+ */
+    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
+
+ /**
+ * Runs a benchmark with the specified parameters.
+ */
+ suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
+
+ /**
+ * Unloads the currently loaded model.
+ */
+ suspend fun cleanUp()
+
+ /**
+ * Cleans up resources when the engine is no longer needed.
+ */
+ fun destroy()
+
+ /**
+ * States of the inference engine
+ */
+ sealed class State {
+ object Uninitialized : State()
+ object Initializing : State()
+ object Initialized : State()
+
+ object LoadingModel : State()
+ object UnloadingModel : State()
+ object ModelReady : State()
+
+ object Benchmarking : State()
+ object ProcessingSystemPrompt : State()
+ object ProcessingUserPrompt : State()
+
+ object Generating : State()
+
+ data class Error(val exception: Exception) : State()
+ }
+
+ companion object {
+ const val DEFAULT_PREDICT_LENGTH = 1024
+ }
+}
+
+val State.isUninterruptible
+ get() = this is State.Initializing ||
+ this is State.LoadingModel ||
+ this is State.UnloadingModel ||
+ this is State.Benchmarking ||
+ this is State.ProcessingSystemPrompt ||
+ this is State.ProcessingUserPrompt
+
+val State.isModelLoaded: Boolean
+ get() = this is State.ModelReady ||
+ this is State.Benchmarking ||
+ this is State.ProcessingSystemPrompt ||
+ this is State.ProcessingUserPrompt ||
+ this is State.Generating
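+
+// For example, UI code might check `engine.state.value.isModelLoaded` (for some
+// [InferenceEngine] `engine`) before enabling the prompt input field.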
+
+class UnsupportedArchitectureException : Exception()
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
new file mode 100644
index 0000000000..2f15eef077
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
@@ -0,0 +1,61 @@
+package com.arm.aichat.gguf
+
+
+/**
+ * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
+ * The `label` matches what llama‑cli prints.
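+ *
+ * For example, `fromCode(15)` maps to [MOSTLY_Q4_K_M] ("Q4_K - Medium"); unknown codes map to [UNKNOWN].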
+ */
+enum class FileType(val code: Int, val label: String) {
+ ALL_F32(0, "all F32"),
+ MOSTLY_F16(1, "F16"),
+ MOSTLY_Q4_0(2, "Q4_0"),
+ MOSTLY_Q4_1(3, "Q4_1"),
+    // 4, 5, 6 removed
+ MOSTLY_Q8_0(7, "Q8_0"),
+ MOSTLY_Q5_0(8, "Q5_0"),
+ MOSTLY_Q5_1(9, "Q5_1"),
+
+ /* K‑quants ------------------------------------------------------------ */
+ MOSTLY_Q2_K (10, "Q2_K - Medium"),
+ MOSTLY_Q3_K_S (11, "Q3_K - Small"),
+ MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
+ MOSTLY_Q3_K_L (13, "Q3_K - Large"),
+ MOSTLY_Q4_K_S (14, "Q4_K - Small"),
+ MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
+ MOSTLY_Q5_K_S (16, "Q5_K - Small"),
+ MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
+ MOSTLY_Q6_K (18, "Q6_K"),
+
+ /* IQ quants ----------------------------------------------------------- */
+ MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
+ MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
+ MOSTLY_Q2_K_S (21, "Q2_K - Small"),
+ MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
+ MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
+ MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
+ MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
+ MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
+ MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
+ MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
+ MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
+ MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
+ MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
+
+ /* BF16 & Ternary ------------------------------------------------------ */
+ MOSTLY_BF16 (32, "BF16"),
+ MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
+ MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
+
+ /* Special flag -------------------------------------------------------- */
+ GUESSED(1024, "(guessed)"),
+
+ UNKNOWN(-1, "unknown");
+
+ companion object {
+ private val map = entries.associateBy(FileType::code)
+
+ fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
+ }
+}
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
new file mode 100644
index 0000000000..5e1971ae2f
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
@@ -0,0 +1,132 @@
+package com.arm.aichat.gguf
+
+import java.io.IOException
+
+
+/**
+ * Structured metadata of GGUF
+ */
+data class GgufMetadata(
+ // Basic file info
+ val version: GgufVersion,
+ val tensorCount: Long,
+ val kvCount: Long,
+
+ // General info
+ val basic: BasicInfo,
+ val author: AuthorInfo? = null,
+ val additional: AdditionalInfo? = null,
+ val architecture: ArchitectureInfo? = null,
+    val baseModels: List<BaseModelInfo>? = null,
+ val tokenizer: TokenizerInfo? = null,
+
+ // Derivative info
+ val dimensions: DimensionsInfo? = null,
+ val attention: AttentionInfo? = null,
+ val rope: RopeInfo? = null,
+ val experts: ExpertsInfo? = null
+) {
+ enum class GgufVersion(val code: Int, val label: String) {
+ /** First public draft; little‑endian only, no alignment key. */
+ LEGACY_V1(1, "Legacy v1"),
+
+ /** Added split‑file support and some extra metadata keys. */
+ EXTENDED_V2(2, "Extended v2"),
+
+ /** Current spec: endian‑aware, mandatory alignment, fully validated. */
+ VALIDATED_V3(3, "Validated v3");
+
+ companion object {
+ fun fromCode(code: Int): GgufVersion =
+ entries.firstOrNull { it.code == code }
+ ?: throw IOException("Unknown GGUF version code $code")
+ }
+
+ override fun toString(): String = "$label (code=$code)"
+ }
+
+ data class BasicInfo(
+ val uuid: String? = null,
+ val name: String? = null,
+ val nameLabel: String? = null,
+ val sizeLabel: String? = null, // Size label like "7B"
+ )
+
+ data class AuthorInfo(
+ val organization: String? = null,
+ val author: String? = null,
+ val doi: String? = null,
+ val url: String? = null,
+ val repoUrl: String? = null,
+ val license: String? = null,
+ val licenseLink: String? = null,
+ )
+
+ data class AdditionalInfo(
+ val type: String? = null,
+ val description: String? = null,
+        val tags: List<String>? = null,
+        val languages: List<String>? = null,
+ )
+
+ data class ArchitectureInfo(
+ val architecture: String? = null,
+ val fileType: Int? = null,
+ val vocabSize: Int? = null,
+ val finetune: String? = null,
+ val quantizationVersion: Int? = null,
+ )
+
+ data class BaseModelInfo(
+ val name: String? = null,
+ val author: String? = null,
+ val version: String? = null,
+ val organization: String? = null,
+ val url: String? = null,
+ val doi: String? = null,
+ val uuid: String? = null,
+ val repoUrl: String? = null,
+ )
+
+ data class TokenizerInfo(
+ val model: String? = null,
+ val bosTokenId: Int? = null,
+ val eosTokenId: Int? = null,
+ val unknownTokenId: Int? = null,
+ val paddingTokenId: Int? = null,
+ val addBosToken: Boolean? = null,
+ val addEosToken: Boolean? = null,
+ val chatTemplate: String? = null,
+ )
+
+ data class DimensionsInfo(
+ val contextLength: Int? = null,
+ val embeddingSize: Int? = null,
+ val blockCount: Int? = null,
+ val feedForwardSize: Int? = null,
+ )
+
+ data class AttentionInfo(
+ val headCount: Int? = null,
+ val headCountKv: Int? = null,
+ val keyLength: Int? = null,
+ val valueLength: Int? = null,
+ val layerNormEpsilon: Float? = null,
+ val layerNormRmsEpsilon: Float? = null,
+ )
+
+ data class RopeInfo(
+ val frequencyBase: Float? = null,
+ val dimensionCount: Int? = null,
+ val scalingType: String? = null,
+ val scalingFactor: Float? = null,
+ val attnFactor: Float? = null,
+ val originalContextLength: Int? = null,
+ val finetuned: Boolean? = null,
+ )
+
+ data class ExpertsInfo(
+ val count: Int? = null,
+ val usedCount: Int? = null,
+ )
+}
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
new file mode 100644
index 0000000000..264a6c0bda
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
@@ -0,0 +1,77 @@
+package com.arm.aichat.gguf
+
+import android.content.Context
+import android.net.Uri
+import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
+import java.io.File
+import java.io.IOException
+import java.io.InputStream
+
+/**
+ * Interface for reading GGUF metadata from model files.
+ * Use `GgufMetadataReader.create()` to get an instance.
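+ *
+ * A minimal usage sketch from a coroutine (the file name is illustrative):
+ * ```
+ * val reader = GgufMetadataReader.create()
+ * val file = File(context.filesDir, "model.gguf")
+ * if (reader.ensureSourceFileFormat(file)) {
+ *     val metadata = file.inputStream().buffered().use { reader.readStructuredMetadata(it) }
+ *     println("arch=${metadata.architecture?.architecture} ctx=${metadata.dimensions?.contextLength}")
+ * }
+ * ```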
+ */
+interface GgufMetadataReader {
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param file Java File to the GGUF file with absolute path
+ * @return true if file is valid GGUF, otherwise false
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun ensureSourceFileFormat(file: File): Boolean
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param context Context for obtaining [android.content.ContentProvider]
+ * @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
+ * @return true if file is valid GGUF, otherwise false
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
+
+ /**
+ * Reads and parses GGUF metadata from the specified file path.
+ *
+ * @param input the [InputStream] obtained from a readable file or content
+ * @return Structured metadata extracted from the file
+ * @throws IOException if file is damaged or cannot be read
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
+
+ companion object {
+ private val DEFAULT_SKIP_KEYS = setOf(
+ "tokenizer.chat_template",
+ "tokenizer.ggml.scores",
+ "tokenizer.ggml.tokens",
+ "tokenizer.ggml.token_type"
+ )
+
+ /**
+ * Creates a default GgufMetadataReader instance
+ */
+ fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
+ skipKeys = DEFAULT_SKIP_KEYS,
+ arraySummariseThreshold = 1_000
+ )
+
+ /**
+ * Creates a GgufMetadataReader with custom configuration
+ *
+ * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
+         * @param arraySummariseThreshold If ≥ 0, arrays longer than this are summarised rather than materialised;
+         *                                if -1 (or any negative value), arrays are never summarised.
+ */
+ fun create(
+            skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
+ arraySummariseThreshold: Int = 1_000
+ ): GgufMetadataReader = GgufMetadataReaderImpl(
+ skipKeys = skipKeys,
+ arraySummariseThreshold = arraySummariseThreshold
+ )
+ }
+}
+
+class InvalidFileFormatException : IOException()
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
new file mode 100644
index 0000000000..b9056ea819
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
@@ -0,0 +1,309 @@
+package com.arm.aichat.internal
+
+import android.content.Context
+import android.util.Log
+import com.arm.aichat.InferenceEngine
+import com.arm.aichat.UnsupportedArchitectureException
+import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
+import dalvik.annotation.optimization.FastNative
+import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.cancel
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+import java.io.File
+import java.io.IOException
+
+/**
+ * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
+ *
+ * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
+ * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
+ * with the underlying C++ native code.
+ *
+ * The typical usage flow is:
+ * 1. Get instance via [getInstance]
+ * 2. Load a model with [loadModel]
+ * 3. Send prompts with [sendUserPrompt]
+ * 4. Generate responses as token streams
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
+ *
+ * State transitions are managed automatically and validated at each operation.
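+ *
+ * Callers can observe [state] from a coroutine to drive UI, e.g.
+ * `engine.state.collect { s -> Log.d("AiChat", s.javaClass.simpleName) }`.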
+ *
+ * @see ai_chat.cpp for the native implementation details
+ */
+internal class InferenceEngineImpl private constructor(
+ private val nativeLibDir: String
+) : InferenceEngine {
+
+ companion object {
+ private val TAG = InferenceEngineImpl::class.java.simpleName
+
+ @Volatile
+ private var instance: InferenceEngine? = null
+
+ /**
+ * Create or obtain [InferenceEngineImpl]'s single instance.
+ *
+         * @param context Context for obtaining the native library directory
+ * @throws IllegalArgumentException if native library path is invalid
+ * @throws UnsatisfiedLinkError if library failed to load
+ */
+        internal fun getInstance(context: Context) =
+            instance ?: synchronized(this) {
+                instance ?: run {
+                    val nativeLibDir = context.applicationInfo.nativeLibraryDir
+                    require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
+
+                    try {
+                        Log.i(TAG, "Instantiating InferenceEngineImpl...")
+                        InferenceEngineImpl(nativeLibDir).also { instance = it }
+                    } catch (e: UnsatisfiedLinkError) {
+                        Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
+                        throw e
+                    }
+                }
+            }
+ }
+
+ /**
+ * JNI methods
+ * @see ai_chat.cpp
+ */
+ @FastNative
+ private external fun init(nativeLibDir: String)
+
+ @FastNative
+ private external fun load(modelPath: String): Int
+
+ @FastNative
+ private external fun prepare(): Int
+
+ @FastNative
+ private external fun systemInfo(): String
+
+ @FastNative
+ private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+ @FastNative
+ private external fun processSystemPrompt(systemPrompt: String): Int
+
+ @FastNative
+ private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+
+ @FastNative
+ private external fun generateNextToken(): String?
+
+ @FastNative
+ private external fun unload()
+
+ @FastNative
+ private external fun shutdown()
+
+    private val _state =
+        MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
+    override val state: StateFlow<InferenceEngine.State> = _state
+
+ private var _readyForSystemPrompt = false
+
+ /**
+ * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
+ */
+ @OptIn(ExperimentalCoroutinesApi::class)
+ private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+ private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+ init {
+ llamaScope.launch {
+ try {
+ check(_state.value is InferenceEngine.State.Uninitialized) {
+ "Cannot load native library in ${_state.value.javaClass.simpleName}!"
+ }
+ _state.value = InferenceEngine.State.Initializing
+ Log.i(TAG, "Loading native library...")
+ System.loadLibrary("ai-chat")
+ init(nativeLibDir)
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to load native library", e)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Load the LLM
+ */
+ override suspend fun loadModel(pathToModel: String) =
+ withContext(llamaDispatcher) {
+ check(_state.value is InferenceEngine.State.Initialized) {
+ "Cannot load model in ${_state.value.javaClass.simpleName}!"
+ }
+
+ try {
+ Log.i(TAG, "Checking access to model file... \n$pathToModel")
+ File(pathToModel).let {
+ require(it.exists()) { "File not found" }
+ require(it.isFile) { "Not a valid file" }
+ require(it.canRead()) { "Cannot read file" }
+ }
+
+ Log.i(TAG, "Loading model... \n$pathToModel")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.LoadingModel
+ load(pathToModel).let {
+ // TODO-han.yin: find a better way to pass other error codes
+ if (it != 0) throw UnsupportedArchitectureException()
+ }
+ prepare().let {
+ if (it != 0) throw IOException("Failed to prepare resources")
+ }
+ Log.i(TAG, "Model loaded!")
+ _readyForSystemPrompt = true
+ _state.value = InferenceEngine.State.ModelReady
+ } catch (e: Exception) {
+ Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
+ _state.value = InferenceEngine.State.Error(e)
+ throw e
+ }
+ }
+
+ /**
+ * Process the plain text system prompt
+ *
+     * TODO-han.yin: return an error code if the system prompt is not correctly processed?
+ */
+ override suspend fun setSystemPrompt(prompt: String) =
+ withContext(llamaDispatcher) {
+ require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+ check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+ }
+
+ Log.i(TAG, "Sending system prompt...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.ProcessingSystemPrompt
+ processSystemPrompt(prompt).let { result ->
+ if (result != 0) {
+ RuntimeException("Failed to process system prompt: $result").also {
+ _state.value = InferenceEngine.State.Error(it)
+ throw it
+ }
+ }
+ }
+ Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+ _state.value = InferenceEngine.State.ModelReady
+ }
+
+ /**
+ * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
+ */
+ override fun sendUserPrompt(
+ message: String,
+ predictLength: Int,
+    ): Flow<String> = flow {
+ require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
+ }
+
+ try {
+ Log.i(TAG, "Sending user prompt...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.ProcessingUserPrompt
+
+ processUserPrompt(message, predictLength).let { result ->
+ if (result != 0) {
+ Log.e(TAG, "Failed to process user prompt: $result")
+ return@flow
+ }
+ }
+
+ Log.i(TAG, "User prompt processed. Generating assistant prompt...")
+ _state.value = InferenceEngine.State.Generating
+ while (true) {
+ generateNextToken()?.let { utf8token ->
+ if (utf8token.isNotEmpty()) emit(utf8token)
+ } ?: break
+ }
+ Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ _state.value = InferenceEngine.State.ModelReady
+ } catch (e: CancellationException) {
+ Log.i(TAG, "Generation cancelled by user.")
+ _state.value = InferenceEngine.State.ModelReady
+ throw e
+ } catch (e: Exception) {
+ Log.e(TAG, "Error during generation!", e)
+ _state.value = InferenceEngine.State.Error(e)
+ throw e
+ }
+ }.flowOn(llamaDispatcher)
+
+ /**
+ * Benchmark the model
+ */
+ override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+ withContext(llamaDispatcher) {
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "Benchmark request discarded due to: $state"
+ }
+ Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+ _readyForSystemPrompt = false // Just to be safe
+ _state.value = InferenceEngine.State.Benchmarking
+ benchModel(pp, tg, pl, nr).also {
+ _state.value = InferenceEngine.State.ModelReady
+ }
+ }
+
+ /**
+     * Unloads the model and frees resources, or resets error states
+ */
+ override suspend fun cleanUp() =
+ withContext(llamaDispatcher) {
+ when (val state = _state.value) {
+ is InferenceEngine.State.ModelReady -> {
+ Log.i(TAG, "Unloading model and free resources...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.UnloadingModel
+
+ unload()
+
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "Model unloaded!")
+ Unit
+ }
+
+ is InferenceEngine.State.Error -> {
+ Log.i(TAG, "Resetting error states...")
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "States reset!")
+ Unit
+ }
+
+ else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
+ }
+ }
+
+ /**
+ * Cancel all ongoing coroutines and free GGML backends
+ */
+ override fun destroy() {
+ _readyForSystemPrompt = false
+ llamaScope.cancel()
+ when(_state.value) {
+ is InferenceEngine.State.Uninitialized -> {}
+ is InferenceEngine.State.Initialized -> shutdown()
+ else -> { unload(); shutdown() }
+ }
+ }
+}
diff --git a/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
new file mode 100644
index 0000000000..bf250ac13c
--- /dev/null
+++ b/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
@@ -0,0 +1,590 @@
+package com.arm.aichat.internal.gguf
+
+import android.content.Context
+import android.net.Uri
+import com.arm.aichat.gguf.GgufMetadata
+import com.arm.aichat.gguf.GgufMetadataReader
+import com.arm.aichat.gguf.InvalidFileFormatException
+import java.io.File
+import java.io.IOException
+import java.io.InputStream
+
+
+/**
+ * Utility class to read GGUF model files and extract metadata key-value pairs.
+ * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
+ */
+internal class GgufMetadataReaderImpl(
+    private val skipKeys: Set<String>,
+ private val arraySummariseThreshold: Int,
+) : GgufMetadataReader {
+ companion object {
+ private const val ARCH_LLAMA = "llama"
+ }
+
+ /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
+ enum class MetadataType(val code: Int) {
+ UINT8(0), INT8(1), UINT16(2), INT16(3),
+ UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
+ STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
+ companion object {
+ private val codeMap = entries.associateBy(MetadataType::code)
+ fun fromCode(code: Int): MetadataType = codeMap[code]
+ ?: throw IOException("Unknown metadata value type code: $code")
+ }
+ }
+
+ /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
+ sealed class MetadataValue {
+ data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
+ data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
+ data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
+ data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
+ data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
+ data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
+ data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
+ data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
+ data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
+        data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
+ data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
+ data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
+ data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
+ }
+
+    /* Convert MetadataValue to plain Kotlin primitives (e.g. for exposing a raw metadata map) */
+ private fun MetadataValue.toPrimitive(): Any = when (this) {
+ is MetadataValue.UInt8 -> value
+ is MetadataValue.Int8 -> value
+ is MetadataValue.UInt16 -> value
+ is MetadataValue.Int16 -> value
+ is MetadataValue.UInt32 -> value
+ is MetadataValue.Int32 -> value
+ is MetadataValue.Float32 -> value
+ is MetadataValue.Bool -> value
+ is MetadataValue.StringVal -> value
+ is MetadataValue.UInt64 -> value
+ is MetadataValue.Int64 -> value
+ is MetadataValue.Float64 -> value
+ is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
+ }
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+     * @param file Java [File] to the GGUF file with absolute path
+ * @return true if file is valid GGUF, otherwise false
+ */
+ override suspend fun ensureSourceFileFormat(file: File): Boolean =
+ file.inputStream().buffered().use { ensureMagic(it) }
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param context Context for obtaining ContentResolver
+ * @param uri Uri to the GGUF file provided by ContentProvider
+ * @return true if file is valid GGUF, otherwise false
+ */
+ override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
+ context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
+
+    /** Reads the 4‑byte magic; returns true if it equals "GGUF", throws on a truncated read. */
+ private fun ensureMagic(input: InputStream): Boolean =
+ ByteArray(4).let {
+ if (input.read(it) != 4) throw IOException("Not a valid file!")
+ it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
+ }
+
+ /**
+ * High‑level entry point: parses a `.gguf` file on disk and returns the fully
+ * populated [GgufMetadata] tree.
+ *
+ * Steps performed internally:
+ * 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
+ * 2. Streams through the key‑value section, skipping large blobs if the key
+ * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
+ * 3. Converts the resulting raw map into strongly‑typed sub‑structures
+ * (basic info, tokenizer, rope, etc.).
+ *
+ * The method is STREAMING‑ONLY: tensors are never mapped or loaded into
+ * memory, so even multi‑GB model files can be processed in < 50 ms.
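+     *
+     * The fixed header read here is little-endian throughout: a 4-byte magic ("GGUF"),
+     * a 4-byte version, an 8-byte tensor count and an 8-byte key-value count.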
+ *
+     * @param input Open [InputStream] positioned at the start of a `.gguf` file.
+     * @return A [GgufMetadata] instance containing all recognised metadata.
+ * @throws IOException if the file is not GGUF, the version is unsupported,
+ * or the metadata block is truncated / corrupt.
+ */
+ override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
+ // ── 1. header ──────────────────────────────────────────────────────────
+ // throws on mismatch
+ val version = ensureMagicAndVersion(input)
+ val tensorCount = readLittleLong(input)
+ val kvCount = readLittleLong(input)
+
+ // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
+        val meta = readMetaMap(input, kvCount)
+
+ // ── 3. build structured object ────────────────────────────────────────
+ return buildStructured(meta, version, tensorCount, kvCount)
+ }
+
+ /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
+ private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
+ if (!ensureMagic(input)) throw InvalidFileFormatException()
+ return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
+ }
+
+ /**
+ * Read an unsigned 32‑bit little‑endian integer.
+ *
+ * @throws IOException if fewer than four bytes are available.
+ */
+ private fun readLEUInt32(input: InputStream): Int {
+ val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
+ if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
+ return (b3 and 0xFF shl 24) or
+ (b2 and 0xFF shl 16) or
+ (b1 and 0xFF shl 8) or
+ (b0 and 0xFF)
+ }
+
+ /**
+ * Low‑level helper that reads the entire “key-value” section from the current
+ * stream position.
+ *
+ * @param input Open stream positioned JUST AFTER the header.
+ * @param kvCnt Number of key‑value pairs (taken from the header).
+ * @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
+ *
+ * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
+ * [skipValue] or [parseValue] accordingly.
+ */
+    private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
+        mutableMapOf<String, MetadataValue>().apply {
+ repeat(kvCnt.toInt()) {
+ val key = readString(input)
+ val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ if (key in skipKeys) {
+ skipValue(input, valueT)
+ } else {
+ this[key] = parseValue(input, valueT)
+ }
+ }
+ }
+
+ /**
+ * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
+ * [GgufMetadata] tree used by the rest of the app.
+ *
+ * Only the keys listed in the spec are copied into dedicated data classes;
+ * everything else is preserved in `GgufMetadata.allMetadata`.
+ *
+ * @param m Raw key/value map.
+ * @param version GGUF file‑format version (enum).
+ * @param tensorCnt Number of tensors (from the header).
+ * @param kvCnt Total metadata pair count (from the header).
+ */
+ private fun buildStructured(
+        m: Map<String, MetadataValue>,
+ version: GgufMetadata.GgufVersion,
+ tensorCnt: Long,
+ kvCnt: Long
+ ): GgufMetadata {
+ // ---------- helpers ----------
+ fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
+ fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
+ fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
+ fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
+ fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
+ fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
+        fun String.strList(): List<String>? =
+ (m[this] as? MetadataValue.ArrayVal)
+ ?.elements
+ ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
+
+ val arch = "general.architecture".str() ?: ARCH_LLAMA
+
+ // -------------- populate sections ----------------
+ val basic = GgufMetadata.BasicInfo(
+ uuid = "general.uuid".str(),
+ name = "general.basename".str(),
+ nameLabel = "general.name".str(),
+ sizeLabel = "general.size_label".str()
+ )
+
+ val author = GgufMetadata.AuthorInfo(
+ organization = "general.organization".str(),
+ author = "general.author".str(),
+ doi = "general.doi".str(),
+ url = "general.url".str(),
+ repoUrl = "general.repo_url".str(),
+ license = "general.license".str(),
+ licenseLink = "general.license.link".str()
+ ).takeUnless {
+ organization == null && author == null && doi == null &&
+ url == null && repoUrl == null && license == null && licenseLink == null
+ }
+
+ val additional = GgufMetadata.AdditionalInfo(
+ type = "general.type".str(),
+ description = "general.description".str(),
+ tags = "general.tags".strList(),
+ languages = "general.languages".strList()
+ ).takeUnless {
+ type == null && description == null && tags == null && languages == null
+ }
+
+ val architectureInfo = GgufMetadata.ArchitectureInfo(
+ architecture = arch,
+ fileType = "general.file_type".u32(),
+ vocabSize = "$arch.vocab_size".u32(),
+ finetune = "general.finetune".str(),
+ quantizationVersion = "general.quantization_version".u32()
+ ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
+
+ val baseModels = buildList {
+ val n = "general.base_model.count".u32() ?: 0
+ for (i in 0 until n) {
+ fun k(s: String) = "general.base_model.$i.$s"
+ add(
+ GgufMetadata.BaseModelInfo(
+ name = k("name").str(),
+ author = k("author").str(),
+ version = k("version").str(),
+ organization = k("organization").str(),
+ url = k("url").str(),
+ doi = k("doi").str(),
+ uuid = k("uuid").str(),
+ repoUrl = k("repo_url").str(),
+ )
+ )
+ }
+ }.takeIf { it.isNotEmpty() }
+
+ val tokenizer = GgufMetadata.TokenizerInfo(
+ model = "tokenizer.ggml.model".str(),
+ bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
+ eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
+ unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
+ paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
+ addBosToken = "tokenizer.ggml.add_bos_token".bool(),
+ addEosToken = "tokenizer.ggml.add_eos_token".bool(),
+ chatTemplate = "tokenizer.chat_template".str()
+ ).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
+ unknownTokenId == null && paddingTokenId == null &&
+ addBosToken == null && addEosToken == null && chatTemplate == null
+ }
+
+ val dimensions = GgufMetadata.DimensionsInfo(
+ contextLength = "$arch.context_length".u32(),
+ embeddingSize = "$arch.embedding_length".u32(),
+ blockCount = "$arch.block_count".u32(),
+ feedForwardSize = "$arch.feed_forward_length".u32()
+ ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
+
+ val attention = GgufMetadata.AttentionInfo(
+ headCount = "$arch.attention.head_count".u32(),
+ headCountKv = "$arch.attention.head_count_kv".u32(),
+ keyLength = "$arch.attention.key_length".u32(),
+ valueLength = "$arch.attention.value_length".u32(),
+ layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
+ layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
+ ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
+ layerNormEpsilon == null && layerNormRmsEpsilon == null
+ }
+
+ val rope = GgufMetadata.RopeInfo(
+ frequencyBase = "$arch.rope.freq_base".f32(),
+ dimensionCount = "$arch.rope.dimension_count".u32(),
+ scalingType = "$arch.rope.scaling.type".str(),
+ scalingFactor = "$arch.rope.scaling.factor".f32(),
+ attnFactor = "$arch.rope.scaling.attn_factor".f32(),
+ originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
+ finetuned = "$arch.rope.scaling.finetuned".bool()
+ ).takeUnless { frequencyBase == null && dimensionCount == null &&
+ scalingType == null && scalingFactor == null && attnFactor == null &&
+ originalContextLength == null && finetuned == null
+ }
+
+ val experts = GgufMetadata.ExpertsInfo(
+ count = "$arch.expert_count".u32(),
+ usedCount = "$arch.expert_used_count".u32()
+ ).takeUnless { count == null && usedCount == null }
+
+ return GgufMetadata(
+ version = version,
+ tensorCount = tensorCnt,
+ kvCount = kvCnt,
+ basic = basic,
+ author = author,
+ additional = additional,
+ architecture = architectureInfo,
+ baseModels = baseModels,
+ tokenizer = tokenizer,
+ dimensions = dimensions,
+ attention = attention,
+ rope = rope,
+ experts = experts
+ )
+ }
+
+ /**
+ * Recursively parses a metadata value of the given type from the input stream.
+ * @param input The input stream positioned at the start of the value.
+ * @param type The metadata value type to parse.
+ */
+ private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
+ MetadataType.UINT8 -> {
+ // 1-byte unsigned integer
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
+ MetadataValue.UInt8(byteVal.toUByte())
+ }
+ MetadataType.INT8 -> {
+ // 1-byte signed integer
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
+ MetadataValue.Int8(byteVal.toByte())
+ }
+ MetadataType.UINT16 -> {
+ // 2-byte unsigned integer (little-endian)
+ val bytes = ByteArray(2)
+ if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
+ // Combine two bytes (little-endian) into an unsigned 16-bit value
+ val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
+ MetadataValue.UInt16(u16.toUShort())
+ }
+ MetadataType.INT16 -> {
+ // 2-byte signed integer (little-endian)
+ val bytes = ByteArray(2)
+ if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
+ // Combine to 16-bit and interpret as signed
+ val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
+ MetadataValue.Int16(i16.toShort())
+ }
+ MetadataType.UINT32 -> {
+ // 4-byte unsigned integer (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
+ // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
+ val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ MetadataValue.UInt32(u32.toUInt())
+ }
+ MetadataType.INT32 -> {
+ // 4-byte signed integer (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
+ // Combine four bytes into a 32-bit signed int
+ val i32 = (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+ MetadataValue.Int32(i32)
+ }
+ MetadataType.FLOAT32 -> {
+ // 4-byte IEEE 754 float (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
+ // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
+ val bits = (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+ val floatVal = Float.fromBits(bits)
+ MetadataValue.Float32(floatVal)
+ }
+ MetadataType.BOOL -> {
+ // 1-byte boolean (0 = false, 1 = true)
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
+ if (byteVal != 0 && byteVal != 1) {
+ throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
+ }
+ MetadataValue.Bool(byteVal != 0)
+ }
+ MetadataType.STRING -> {
+ // UTF-8 string (length-prefixed with 8-byte length)
+ val str = readString(input)
+ MetadataValue.StringVal(str)
+ }
+ MetadataType.ARRAY -> {
+ val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ val len = readLittleLong(input)
+ val count = len.toInt()
+
+ if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
+ // fast‑forward without allocation
+ repeat(count) { skipValue(input, elemType) }
+ MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
+ } else {
+                val list = ArrayList<MetadataValue>(count)
+ repeat(count) { list += parseValue(input, elemType) }
+ MetadataValue.ArrayVal(elemType, list)
+ }
+ }
+ MetadataType.UINT64 -> {
+ // 8-byte unsigned integer (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
+ // Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
+ val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
+ (bytes[6].toULong() and 0xFFuL shl 48) or
+ (bytes[5].toULong() and 0xFFuL shl 40) or
+ (bytes[4].toULong() and 0xFFuL shl 32) or
+ (bytes[3].toULong() and 0xFFuL shl 24) or
+ (bytes[2].toULong() and 0xFFuL shl 16) or
+ (bytes[1].toULong() and 0xFFuL shl 8) or
+ (bytes[0].toULong() and 0xFFuL)
+ MetadataValue.UInt64(u64)
+ }
+ MetadataType.INT64 -> {
+ // 8-byte signed integer (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
+ // Combine 8 bytes into a signed 64-bit value (Long)
+ val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ MetadataValue.Int64(i64)
+ }
+ MetadataType.FLOAT64 -> {
+ // 8-byte IEEE 754 double (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
+ // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
+ val bits = (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ val doubleVal = Double.fromBits(bits)
+ MetadataValue.Float64(doubleVal)
+ }
+ }
+
+
+    private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
+ this?.takeIf { !it.check() }
+
+ /** Helper: Skip a value in the stream without storing it (still maintains pointer). */
+ private fun skipValue(input: InputStream, type: MetadataType) {
+ when (type) {
+ MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
+ MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
+ MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
+ MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
+ MetadataType.STRING -> {
+ val len = readLittleLong(input); input.skipFully(len)
+ }
+ MetadataType.ARRAY -> {
+ val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ val len = readLittleLong(input)
+ repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
+ }
+ }
+ }
+
+ /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
+ private fun readLittleLong(input: InputStream): Long {
+ val bytes = ByteArray(8)
+ input.readFully(bytes)
+
+ // Combine 8 bytes into a 64-bit value (Little Endian).
+ // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
+ // In our context (lengths/counts), such extremely large values are not expected.
+ return (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ }
+
+ /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
+ private fun readString(input: InputStream): String =
+ // Read 8-byte little-endian length (number of bytes in the string).
+ readLittleLong(input).let { len ->
+ if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
+
+ // Read the UTF-8 bytes of the given length.
+ ByteArray(len.toInt()).let {
+ if (it.isNotEmpty()) input.readFully(it)
+ String(it, Charsets.UTF_8)
+ }
+ }
+
+ /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
+ private fun littleEndianBytesToInt(bytes: ByteArray): Int =
+ // Note: assumes bytes length is 4.
+ (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+
+ /**
+ * Robust skip that works the same on JDK 11 and Android’s desugared runtime.
+ *
+ * @param n Number of bytes to advance in the stream.
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.skipFully(n: Long) {
+ var remaining = n
+ val scratch = ByteArray(8192) // read‑and‑toss buffer
+ while (remaining > 0) {
+ val skipped = skip(remaining)
+ when {
+ skipped > 0 -> remaining -= skipped // normal fast path
+ skipped == 0L -> {
+ // fallback: read and discard
+ val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
+ if (read == -1) throw IOException("EOF while skipping $n bytes")
+ remaining -= read
+ }
+ else -> throw IOException("Skip returned negative value")
+ }
+ }
+ }
+
+ /**
+     * Extension that keeps reading until the requested number of bytes have been filled,
+     * looping over `read()` because a single call may return fewer bytes than requested.
+ *
+ * @param buf Destination buffer.
+ * @param len Number of bytes to fill (defaults to `buf.size`).
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
+ var off = 0
+ while (off < len) {
+ val n = read(buf, off, len - off)
+ if (n == -1) throw IOException("EOF after $off of $len bytes")
+ off += n
+ }
+ }
+
+ /**
+ * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array.
+ * This is used for small fixed‑length reads (e.g. 4‑byte type codes).
+ *
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
+        readFully(it) // loop until all n bytes have been read; throws IOException on premature EOF
+ }
+}
diff --git a/examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt b/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt
similarity index 100%
rename from examples/llama.android/llama/src/test/java/android/llama/cpp/ExampleUnitTest.kt
rename to examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt
diff --git a/examples/llama.android/llama/build.gradle.kts b/examples/llama.android/llama/build.gradle.kts
deleted file mode 100644
index 5bb6478022..0000000000
--- a/examples/llama.android/llama/build.gradle.kts
+++ /dev/null
@@ -1,71 +0,0 @@
-plugins {
- id("com.android.library")
- id("org.jetbrains.kotlin.android")
-}
-
-android {
- namespace = "android.llama.cpp"
- compileSdk = 34
-
- defaultConfig {
- minSdk = 33
-
- testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
- consumerProguardFiles("consumer-rules.pro")
- ndk {
- // Add NDK properties if wanted, e.g.
- // abiFilters += listOf("arm64-v8a")
- }
- externalNativeBuild {
- cmake {
- arguments += "-DLLAMA_CURL=OFF"
- arguments += "-DLLAMA_BUILD_COMMON=ON"
- arguments += "-DGGML_LLAMAFILE=OFF"
- arguments += "-DCMAKE_BUILD_TYPE=Release"
- cppFlags += listOf()
- arguments += listOf()
-
- cppFlags("")
- }
- }
- }
-
- buildTypes {
- release {
- isMinifyEnabled = false
- proguardFiles(
- getDefaultProguardFile("proguard-android-optimize.txt"),
- "proguard-rules.pro"
- )
- }
- }
- externalNativeBuild {
- cmake {
- path("src/main/cpp/CMakeLists.txt")
- version = "3.22.1"
- }
- }
- compileOptions {
- sourceCompatibility = JavaVersion.VERSION_1_8
- targetCompatibility = JavaVersion.VERSION_1_8
- }
- kotlinOptions {
- jvmTarget = "1.8"
- }
-
- packaging {
- resources {
- excludes += "/META-INF/{AL2.0,LGPL2.1}"
- }
- }
-}
-
-dependencies {
-
- implementation("androidx.core:core-ktx:1.12.0")
- implementation("androidx.appcompat:appcompat:1.6.1")
- implementation("com.google.android.material:material:1.11.0")
- testImplementation("junit:junit:4.13.2")
- androidTestImplementation("androidx.test.ext:junit:1.1.5")
- androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")
-}
diff --git a/examples/llama.android/llama/consumer-rules.pro b/examples/llama.android/llama/consumer-rules.pro
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt b/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
deleted file mode 100644
index 6119fe09b0..0000000000
--- a/examples/llama.android/llama/src/main/cpp/CMakeLists.txt
+++ /dev/null
@@ -1,53 +0,0 @@
-# For more information about using CMake with Android Studio, read the
-# documentation: https://d.android.com/studio/projects/add-native-code.html.
-# For more examples on how to use CMake, see https://github.com/android/ndk-samples.
-
-# Sets the minimum CMake version required for this project.
-cmake_minimum_required(VERSION 3.22.1)
-
-# Declares the project name. The project name can be accessed via ${ PROJECT_NAME},
-# Since this is the top level CMakeLists.txt, the project name is also accessible
-# with ${CMAKE_PROJECT_NAME} (both CMake variables are in-sync within the top level
-# build script scope).
-project("llama-android")
-
-#include(FetchContent)
-#FetchContent_Declare(
-# llama
-# GIT_REPOSITORY https://github.com/ggml-org/llama.cpp
-# GIT_TAG master
-#)
-
-# Also provides "common"
-#FetchContent_MakeAvailable(llama)
-
-# Creates and names a library, sets it as either STATIC
-# or SHARED, and provides the relative paths to its source code.
-# You can define multiple libraries, and CMake builds them for you.
-# Gradle automatically packages shared libraries with your APK.
-#
-# In this top level CMakeLists.txt, ${CMAKE_PROJECT_NAME} is used to define
-# the target library name; in the sub-module's CMakeLists.txt, ${PROJECT_NAME}
-# is preferred for the same purpose.
-#
-
-#load local llama.cpp
-add_subdirectory(../../../../../../ build-llama)
-
-# In order to load a library into your app from Java/Kotlin, you must call
-# System.loadLibrary() and pass the name of the library defined here;
-# for GameActivity/NativeActivity derived applications, the same library name must be
-# used in the AndroidManifest.xml file.
-add_library(${CMAKE_PROJECT_NAME} SHARED
- # List C/C++ source files with relative paths to this CMakeLists.txt.
- llama-android.cpp)
-
-# Specifies libraries CMake should link to your target library. You
-# can link libraries from various origins, such as libraries defined in this
-# build script, prebuilt third-party libraries, or Android system libraries.
-target_link_libraries(${CMAKE_PROJECT_NAME}
- # List libraries link to the target library
- llama
- common
- android
- log)
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
deleted file mode 100644
index 711ddc5d19..0000000000
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ /dev/null
@@ -1,452 +0,0 @@
-#include <android/log.h>
-#include <jni.h>
-#include <iomanip>
-#include <math.h>
-#include <string>
-#include <unistd.h>
-#include "llama.h"
-#include "common.h"
-
-// Write C++ code here.
-//
-// Do not forget to dynamically load the C++ library into your application.
-//
-// For instance,
-//
-// In MainActivity.java:
-// static {
-// System.loadLibrary("llama-android");
-// }
-//
-// Or, in MainActivity.kt:
-// companion object {
-// init {
-// System.loadLibrary("llama-android")
-// }
-// }
-
-#define TAG "llama-android.cpp"
-#define LOGi(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
-#define LOGe(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
-
-jclass la_int_var;
-jmethodID la_int_var_value;
-jmethodID la_int_var_inc;
-
-std::string cached_token_chars;
-
-bool is_valid_utf8(const char * string) {
- if (!string) {
- return true;
- }
-
- const unsigned char * bytes = (const unsigned char *)string;
- int num;
-
- while (*bytes != 0x00) {
- if ((*bytes & 0x80) == 0x00) {
- // U+0000 to U+007F
- num = 1;
- } else if ((*bytes & 0xE0) == 0xC0) {
- // U+0080 to U+07FF
- num = 2;
- } else if ((*bytes & 0xF0) == 0xE0) {
- // U+0800 to U+FFFF
- num = 3;
- } else if ((*bytes & 0xF8) == 0xF0) {
- // U+10000 to U+10FFFF
- num = 4;
- } else {
- return false;
- }
-
- bytes += 1;
- for (int i = 1; i < num; ++i) {
- if ((*bytes & 0xC0) != 0x80) {
- return false;
- }
- bytes += 1;
- }
- }
-
- return true;
-}
-
-static void log_callback(ggml_log_level level, const char * fmt, void * data) {
- if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
- else if (level == GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
- else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring filename) {
- llama_model_params model_params = llama_model_default_params();
-
- auto path_to_model = env->GetStringUTFChars(filename, 0);
- LOGi("Loading model from %s", path_to_model);
-
- auto model = llama_model_load_from_file(path_to_model, model_params);
- env->ReleaseStringUTFChars(filename, path_to_model);
-
- if (!model) {
- LOGe("load_model() failed");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"), "load_model() failed");
- return 0;
- }
-
-    return reinterpret_cast<jlong>(model);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
-    llama_model_free(reinterpret_cast<llama_model *>(model));
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmodel) {
-    auto model = reinterpret_cast<llama_model *>(jmodel);
-
- if (!model) {
- LOGe("new_context(): model cannot be null");
- env->ThrowNew(env->FindClass("java/lang/IllegalArgumentException"), "Model cannot be null");
- return 0;
- }
-
- int n_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
- LOGi("Using %d threads", n_threads);
-
- llama_context_params ctx_params = llama_context_default_params();
-
- ctx_params.n_ctx = 2048;
- ctx_params.n_threads = n_threads;
- ctx_params.n_threads_batch = n_threads;
-
- llama_context * context = llama_new_context_with_model(model, ctx_params);
-
- if (!context) {
- LOGe("llama_new_context_with_model() returned null)");
- env->ThrowNew(env->FindClass("java/lang/IllegalStateException"),
- "llama_new_context_with_model() returned null)");
- return 0;
- }
-
-    return reinterpret_cast<jlong>(context);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1context(JNIEnv *, jobject, jlong context) {
-    llama_free(reinterpret_cast<llama_context *>(context));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1free(JNIEnv *, jobject) {
- llama_backend_free();
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_log_1to_1android(JNIEnv *, jobject) {
- llama_log_set(log_callback, NULL);
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_bench_1model(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong model_pointer,
- jlong batch_pointer,
- jint pp,
- jint tg,
- jint pl,
- jint nr
- ) {
- auto pp_avg = 0.0;
- auto tg_avg = 0.0;
- auto pp_std = 0.0;
- auto tg_std = 0.0;
-
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
- const int n_ctx = llama_n_ctx(context);
-
- LOGi("n_ctx = %d", n_ctx);
-
- int i, j;
- int nri;
- for (nri = 0; nri < nr; nri++) {
- LOGi("Benchmark prompt processing (pp)");
-
- common_batch_clear(*batch);
-
- const int n_tokens = pp;
- for (i = 0; i < n_tokens; i++) {
- common_batch_add(*batch, 0, i, { 0 }, false);
- }
-
- batch->logits[batch->n_tokens - 1] = true;
- llama_memory_clear(llama_get_memory(context), false);
-
- const auto t_pp_start = ggml_time_us();
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during prompt processing");
- }
- const auto t_pp_end = ggml_time_us();
-
- // bench text generation
-
- LOGi("Benchmark text generation (tg)");
-
- llama_memory_clear(llama_get_memory(context), false);
- const auto t_tg_start = ggml_time_us();
- for (i = 0; i < tg; i++) {
-
- common_batch_clear(*batch);
- for (j = 0; j < pl; j++) {
- common_batch_add(*batch, 0, i, { j }, true);
- }
-
- LOGi("llama_decode() text generation: %d", i);
- if (llama_decode(context, *batch) != 0) {
- LOGi("llama_decode() failed during text generation");
- }
- }
-
- const auto t_tg_end = ggml_time_us();
-
- llama_memory_clear(llama_get_memory(context), false);
-
- const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
- const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
-
- const auto speed_pp = double(pp) / t_pp;
- const auto speed_tg = double(pl * tg) / t_tg;
-
- pp_avg += speed_pp;
- tg_avg += speed_tg;
-
- pp_std += speed_pp * speed_pp;
- tg_std += speed_tg * speed_tg;
-
- LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
- }
-
- pp_avg /= double(nr);
- tg_avg /= double(nr);
-
- if (nr > 1) {
- pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
- tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
- } else {
- pp_std = 0;
- tg_std = 0;
- }
-
- char model_desc[128];
- llama_model_desc(model, model_desc, sizeof(model_desc));
-
- const auto model_size = double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0;
- const auto model_n_params = double(llama_model_n_params(model)) / 1e9;
-
- const auto backend = "(Android)"; // TODO: What should this be?
-
- std::stringstream result;
- result << std::setprecision(2);
- result << "| model | size | params | backend | test | t/s |\n";
- result << "| --- | --- | --- | --- | --- | --- |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
- result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
-
- return env->NewStringUTF(result.str().c_str());
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
- // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
- llama_batch *batch = new llama_batch {
- 0,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- nullptr,
- };
-
- if (embd) {
- batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
- } else {
- batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
- }
-
- batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
- batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
- batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
- for (int i = 0; i < n_tokens; ++i) {
- batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
- }
- batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
-
-    return reinterpret_cast<jlong>(batch);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
- delete batch;
-}
-
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_android_llama_cpp_LLamaAndroid_new_1sampler(JNIEnv *, jobject) {
- auto sparams = llama_sampler_chain_default_params();
- sparams.no_perf = true;
- llama_sampler * smpl = llama_sampler_chain_init(sparams);
- llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
-
-    return reinterpret_cast<jlong>(smpl);
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_free_1sampler(JNIEnv *, jobject, jlong sampler_pointer) {
-    llama_sampler_free(reinterpret_cast<llama_sampler *>(sampler_pointer));
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_backend_1init(JNIEnv *, jobject) {
- llama_backend_init();
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_system_1info(JNIEnv *env, jobject) {
- return env->NewStringUTF(llama_print_system_info());
-}
-
-extern "C"
-JNIEXPORT jint JNICALL
-Java_android_llama_cpp_LLamaAndroid_completion_1init(
- JNIEnv *env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jstring jtext,
- jboolean format_chat,
- jint n_len
- ) {
-
- cached_token_chars.clear();
-
- const auto text = env->GetStringUTFChars(jtext, 0);
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-
- bool parse_special = (format_chat == JNI_TRUE);
- const auto tokens_list = common_tokenize(context, text, true, parse_special);
-
- auto n_ctx = llama_n_ctx(context);
- auto n_kv_req = tokens_list.size() + n_len;
-
- LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
-
- if (n_kv_req > n_ctx) {
- LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
- }
-
- for (auto id : tokens_list) {
- LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
- }
-
- common_batch_clear(*batch);
-
- // evaluate the initial prompt
- for (auto i = 0; i < tokens_list.size(); i++) {
- common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
- }
-
- // llama_decode will output logits only for the last token of the prompt
- batch->logits[batch->n_tokens - 1] = true;
-
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() failed");
- }
-
- env->ReleaseStringUTFChars(jtext, text);
-
- return batch->n_tokens;
-}
-
-extern "C"
-JNIEXPORT jstring JNICALL
-Java_android_llama_cpp_LLamaAndroid_completion_1loop(
- JNIEnv * env,
- jobject,
- jlong context_pointer,
- jlong batch_pointer,
- jlong sampler_pointer,
- jint n_len,
- jobject intvar_ncur
-) {
-    const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
- const auto model = llama_get_model(context);
- const auto vocab = llama_model_get_vocab(model);
-
- if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
- if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
- if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
-
- // sample the most likely token
- const auto new_token_id = llama_sampler_sample(sampler, context, -1);
-
- const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
- if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
- return nullptr;
- }
-
- auto new_token_chars = common_token_to_piece(context, new_token_id);
- cached_token_chars += new_token_chars;
-
- jstring new_token = nullptr;
- if (is_valid_utf8(cached_token_chars.c_str())) {
- new_token = env->NewStringUTF(cached_token_chars.c_str());
- LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
- cached_token_chars.clear();
- } else {
- new_token = env->NewStringUTF("");
- }
-
- common_batch_clear(*batch);
- common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
-
- env->CallVoidMethod(intvar_ncur, la_int_var_inc);
-
- if (llama_decode(context, *batch) != 0) {
- LOGe("llama_decode() returned null");
- }
-
- return new_token;
-}
-
-extern "C"
-JNIEXPORT void JNICALL
-Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
-    llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
-}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
deleted file mode 100644
index b964d93e37..0000000000
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ /dev/null
@@ -1,180 +0,0 @@
-package android.llama.cpp
-
-import android.util.Log
-import kotlinx.coroutines.CoroutineDispatcher
-import kotlinx.coroutines.asCoroutineDispatcher
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.flowOn
-import kotlinx.coroutines.withContext
-import java.util.concurrent.Executors
-import kotlin.concurrent.thread
-
-class LLamaAndroid {
- private val tag: String? = this::class.simpleName
-
-    private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
-
- private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
- thread(start = false, name = "Llm-RunLoop") {
- Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
-
- // No-op if called more than once.
- System.loadLibrary("llama-android")
-
- // Set llama log handler to Android
- log_to_android()
- backend_init(false)
-
- Log.d(tag, system_info())
-
- it.run()
- }.apply {
- uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
- Log.e(tag, "Unhandled exception", exception)
- }
- }
- }.asCoroutineDispatcher()
-
- private val nlen: Int = 64
-
- private external fun log_to_android()
- private external fun load_model(filename: String): Long
- private external fun free_model(model: Long)
- private external fun new_context(model: Long): Long
- private external fun free_context(context: Long)
- private external fun backend_init(numa: Boolean)
- private external fun backend_free()
- private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
- private external fun free_batch(batch: Long)
- private external fun new_sampler(): Long
- private external fun free_sampler(sampler: Long)
- private external fun bench_model(
- context: Long,
- model: Long,
- batch: Long,
- pp: Int,
- tg: Int,
- pl: Int,
- nr: Int
- ): String
-
- private external fun system_info(): String
-
- private external fun completion_init(
- context: Long,
- batch: Long,
- text: String,
- formatChat: Boolean,
- nLen: Int
- ): Int
-
- private external fun completion_loop(
- context: Long,
- batch: Long,
- sampler: Long,
- nLen: Int,
- ncur: IntVar
- ): String?
-
- private external fun kv_cache_clear(context: Long)
-
- suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
- return withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- Log.d(tag, "bench(): $state")
- bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
- }
-
- else -> throw IllegalStateException("No model loaded")
- }
- }
- }
-
- suspend fun load(pathToModel: String) {
- withContext(runLoop) {
- when (threadLocalState.get()) {
- is State.Idle -> {
- val model = load_model(pathToModel)
- if (model == 0L) throw IllegalStateException("load_model() failed")
-
- val context = new_context(model)
- if (context == 0L) throw IllegalStateException("new_context() failed")
-
- val batch = new_batch(512, 0, 1)
- if (batch == 0L) throw IllegalStateException("new_batch() failed")
-
- val sampler = new_sampler()
- if (sampler == 0L) throw IllegalStateException("new_sampler() failed")
-
- Log.i(tag, "Loaded model $pathToModel")
- threadLocalState.set(State.Loaded(model, context, batch, sampler))
- }
- else -> throw IllegalStateException("Model already loaded")
- }
- }
- }
-
-    fun send(message: String, formatChat: Boolean = false): Flow<String> = flow {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- val ncur = IntVar(completion_init(state.context, state.batch, message, formatChat, nlen))
- while (ncur.value <= nlen) {
- val str = completion_loop(state.context, state.batch, state.sampler, nlen, ncur)
- if (str == null) {
- break
- }
- emit(str)
- }
- kv_cache_clear(state.context)
- }
- else -> {}
- }
- }.flowOn(runLoop)
-
- /**
- * Unloads the model and frees resources.
- *
- * This is a no-op if there's no model loaded.
- */
- suspend fun unload() {
- withContext(runLoop) {
- when (val state = threadLocalState.get()) {
- is State.Loaded -> {
- free_context(state.context)
- free_model(state.model)
- free_batch(state.batch)
- free_sampler(state.sampler);
-
- threadLocalState.set(State.Idle)
- }
- else -> {}
- }
- }
- }
-
- companion object {
- private class IntVar(value: Int) {
- @Volatile
- var value: Int = value
- private set
-
- fun inc() {
- synchronized(this) {
- value += 1
- }
- }
- }
-
- private sealed interface State {
- data object Idle: State
- data class Loaded(val model: Long, val context: Long, val batch: Long, val sampler: Long): State
- }
-
- // Enforce only one instance of Llm.
- private val _instance: LLamaAndroid = LLamaAndroid()
-
- fun instance(): LLamaAndroid = _instance
- }
-}
diff --git a/examples/llama.android/settings.gradle.kts b/examples/llama.android/settings.gradle.kts
index c7c1a034a4..74f4eb3e46 100644
--- a/examples/llama.android/settings.gradle.kts
+++ b/examples/llama.android/settings.gradle.kts
@@ -8,11 +8,11 @@ pluginManagement {
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
- google()
mavenCentral()
+ google()
}
}
-rootProject.name = "LlamaAndroid"
+rootProject.name = "AiChat"
include(":app")
-include(":llama")
+include(":lib")
diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md
index 05d95d588b..8163b306b4 100644
--- a/examples/model-conversion/README.md
+++ b/examples/model-conversion/README.md
@@ -10,6 +10,13 @@ and in some cases perplexity checked of the quantized model. And finally the
 model/models need to be uploaded to the ggml-org on Hugging Face. This tool/example tries to
help with this process.
+> 📝 **Note:** When adding a new model from an existing family, first verify that
+> the previous version passes logits verification. Existing models can have
+> subtle numerical differences that don't affect generation quality but still
+> cause logits mismatches. Identifying these upfront, whether they originate in
+> llama.cpp, the conversion script, or an upstream implementation, can save
+> significant debugging time.
+
### Overview
The idea is that the makefile targets and scripts here can be used in the
development/conversion process assisting with things like:
diff --git a/examples/model-conversion/scripts/causal/modelcard.template b/examples/model-conversion/scripts/causal/modelcard.template
index 87800a1b93..cfa8e6b433 100644
--- a/examples/model-conversion/scripts/causal/modelcard.template
+++ b/examples/model-conversion/scripts/causal/modelcard.template
@@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:
```sh
-llama-server -hf {namespace}/{model_name}-GGUF -c 0 -fa
+llama-server -hf {namespace}/{model_name}-GGUF -c 0
```
Then, access http://localhost:8080
diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py
index da1132c003..14bb12fe68 100755
--- a/examples/model-conversion/scripts/causal/run-org-model.py
+++ b/examples/model-conversion/scripts/causal/run-org-model.py
@@ -2,134 +2,22 @@
import argparse
import os
+import sys
import importlib
from pathlib import Path
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+# Add parent directory to path for imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
import torch
import numpy as np
-
-### If you want to dump RoPE activations, apply this monkey patch to the model
-### class from Transformers that you are running (replace apertus.modeling_apertus
-### with the proper package and class for your model
-### === START ROPE DEBUG ===
-# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
-
-# orig_rope = apply_rotary_pos_emb
-# torch.set_printoptions(threshold=float('inf'))
-# torch.set_printoptions(precision=6, sci_mode=False)
-
-# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-# # log inputs
-# summarize(q, "RoPE.q_in")
-# summarize(k, "RoPE.k_in")
-
-# # call original
-# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
-
-# # log outputs
-# summarize(q_out, "RoPE.q_out")
-# summarize(k_out, "RoPE.k_out")
-
-# return q_out, k_out
-
-# # Patch it
-# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
-# apertus_mod.apply_rotary_pos_emb = debug_rope
-### == END ROPE DEBUG ===
-
-
-def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
- """
- Print a tensor in llama.cpp debug style.
-
- Supports:
- - 2D tensors (seq, hidden)
- - 3D tensors (batch, seq, hidden)
- - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
-
- Shows first and last max_vals of each vector per sequence position.
- """
- t = tensor.detach().to(torch.float32).cpu()
-
- # Determine dimensions
- if t.ndim == 3:
- _, s, _ = t.shape
- elif t.ndim == 2:
- _, s = 1, t.shape[0]
- t = t.unsqueeze(0)
- elif t.ndim == 4:
- _, s, _, _ = t.shape
- else:
- print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
- return
-
- ten_shape = t.shape
-
- print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
- print(" [")
- print(" [")
-
- # Determine indices for first and last sequences
- first_indices = list(range(min(s, max_seq)))
- last_indices = list(range(max(0, s - max_seq), s))
-
- # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
- has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
-
- # Combine indices
- if has_overlap:
- # If there's overlap, just use the combined unique indices
- indices = sorted(list(set(first_indices + last_indices)))
- separator_index = None
- else:
- # If no overlap, we'll add a separator between first and last sequences
- indices = first_indices + last_indices
- separator_index = len(first_indices)
-
- for i, si in enumerate(indices):
- # Add separator if needed
- if separator_index is not None and i == separator_index:
- print(" ...")
-
- # Extract appropriate slice
- vec = t[0, si]
- if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
- flat = vec.flatten().tolist()
- else: # 2D or 3D case
- flat = vec.tolist()
-
- # First and last slices
- first = flat[:max_vals]
- last = flat[-max_vals:] if len(flat) >= max_vals else flat
- first_str = ", ".join(f"{v:12.4f}" for v in first)
- last_str = ", ".join(f"{v:12.4f}" for v in last)
-
- print(f" [{first_str}, ..., {last_str}]")
-
- print(" ],")
- print(" ]")
- print(f" sum = {t.sum().item():.6f}\n")
-
-
-def debug_hook(name):
- def fn(_m, input, output):
- if isinstance(input, torch.Tensor):
- summarize(input, name + "_in")
- elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
- summarize(input[0], name + "_in")
- if isinstance(output, torch.Tensor):
- summarize(output, name + "_out")
- elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
- summarize(output[0], name + "_out")
-
- return fn
-
-
-unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+from utils.common import debug_hook
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
+parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output")
args = parser.parse_args()
model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -138,18 +26,30 @@ if model_path is None:
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
)
+### If you want to dump RoPE activations, uncomment the following lines:
+### === START ROPE DEBUG ===
+# from utils.common import setup_rope_debug
+# setup_rope_debug("transformers.models.apertus.modeling_apertus")
+### == END ROPE DEBUG ===
+
print("Loading model and tokenizer using AutoTokenizer:", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config
print("Model type: ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+ config = config.text_config
+ multimodal = True
print("Vocab size: ", config.vocab_size)
print("Hidden size: ", config.hidden_size)
print("Number of layers: ", config.num_hidden_layers)
print("BOS token id: ", config.bos_token_id)
print("EOS token id: ", config.eos_token_id)
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
unreleased_module_path = (
@@ -169,13 +69,19 @@ if unreleased_model_name:
print(f"Failed to import or load model: {e}")
exit(1)
else:
- model = AutoModelForCausalLM.from_pretrained(
- model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
- )
+ if multimodal:
+ model = AutoModelForImageTextToText.from_pretrained(
+ model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+ )
+ else:
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
+ )
-for name, module in model.named_modules():
- if len(list(module.children())) == 0: # only leaf modules
- module.register_forward_hook(debug_hook(name))
+if args.verbose:
+ for name, module in model.named_modules():
+ if len(list(module.children())) == 0: # only leaf modules
+ module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful
@@ -185,7 +91,10 @@ model_name = os.path.basename(model_path)
print(f"Model class: {model.__class__.__name__}")
device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+ with open(args.prompt_file, encoding='utf-8') as f:
+ prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
@@ -195,9 +104,18 @@ print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")
print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")
+batch_size = 512
+
with torch.no_grad():
- outputs = model(input_ids.to(model.device))
- logits = outputs.logits
+ past = None
+ outputs = None
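+    # Prefill the prompt in fixed-size chunks, carrying the KV cache (past_key_values)
+    # forward so that long prompts do not require a single huge forward pass.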
+ for i in range(0, input_ids.size(1), batch_size):
+ print(f"Processing chunk with tokens {i} to {i + batch_size}")
+ chunk = input_ids[:, i:i + batch_size]
+ outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+ past = outputs.past_key_values
+
+ logits = outputs.logits # type: ignore
# Extract logits for the last token (next token prediction)
last_logits = logits[0, -1, :].float().cpu().numpy()
diff --git a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
index c48af3075c..984d03e95d 100755
--- a/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh
@@ -34,8 +34,11 @@ done
MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}"
MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
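+# The llama.cpp embeddings file below is keyed on the converted GGUF model's name.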
+CONVERTED_MODEL_PATH="${CONVERTED_EMBEDDING_PATH:-"$CONVERTED_EMBEDDING_MODEL"}"
+CONVERTED_MODEL_NAME="${CONVERTED_MODEL_NAME:-$(basename "$CONVERTED_MODEL_PATH" .gguf)}"
+
if [ -t 0 ]; then
- CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
+ CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
else
# Process piped JSON data and convert to binary (matching logits.cpp format)
TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.binn)
diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py
index 945f9a1a1d..7595d0410e 100644
--- a/examples/model-conversion/scripts/utils/common.py
+++ b/examples/model-conversion/scripts/utils/common.py
@@ -2,6 +2,8 @@
import os
import sys
+import torch
+
def get_model_name_from_env_path(env_path_name):
model_path = os.getenv(env_path_name)
@@ -18,3 +20,131 @@ def get_model_name_from_env_path(env_path_name):
name = name[:-5]
return name
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+ """
+ Print a tensor in llama.cpp debug style.
+
+ Supports:
+ - 2D tensors (seq, hidden)
+ - 3D tensors (batch, seq, hidden)
+ - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+ Shows first and last max_vals of each vector per sequence position.
+ """
+ t = tensor.detach().to(torch.float32).cpu()
+
+ # Determine dimensions
+ if t.ndim == 3:
+ _, s, _ = t.shape
+ elif t.ndim == 2:
+ _, s = 1, t.shape[0]
+ t = t.unsqueeze(0)
+ elif t.ndim == 4:
+ _, s, _, _ = t.shape
+ else:
+ print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+ return
+
+ ten_shape = t.shape
+
+ print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
+ print(" [")
+ print(" [")
+
+ # Determine indices for first and last sequences
+ first_indices = list(range(min(s, max_seq)))
+ last_indices = list(range(max(0, s - max_seq), s))
+
+ # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+ has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+ # Combine indices
+ if has_overlap:
+ # If there's overlap, just use the combined unique indices
+ indices = sorted(list(set(first_indices + last_indices)))
+ separator_index = None
+ else:
+ # If no overlap, we'll add a separator between first and last sequences
+ indices = first_indices + last_indices
+ separator_index = len(first_indices)
+
+ for i, si in enumerate(indices):
+ # Add separator if needed
+ if separator_index is not None and i == separator_index:
+ print(" ...")
+
+ # Extract appropriate slice
+ vec = t[0, si]
+ if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
+ flat = vec.flatten().tolist()
+ else: # 2D or 3D case
+ flat = vec.tolist()
+
+ # First and last slices
+ first = flat[:max_vals]
+ last = flat[-max_vals:] if len(flat) >= max_vals else flat
+ first_str = ", ".join(f"{v:12.4f}" for v in first)
+ last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+ print(f" [{first_str}, ..., {last_str}]")
+
+ print(" ],")
+ print(" ]")
+ print(f" sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
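+    """Return a forward hook that prints a module's first tensor input and output via summarize()."""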
+ def fn(_m, input, output):
+ if isinstance(input, torch.Tensor):
+ summarize(input, name + "_in")
+ elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
+ summarize(input[0], name + "_in")
+ if isinstance(output, torch.Tensor):
+ summarize(output, name + "_out")
+ elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
+ summarize(output[0], name + "_out")
+
+ return fn
+
+
+def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"):
+ """
+ Apply monkey patch to dump RoPE activations for debugging.
+
+ Args:
+ model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus")
+ function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb")
+
+ Example:
+ from utils.common import setup_rope_debug
+ setup_rope_debug("transformers.models.apertus.modeling_apertus")
+ """
+ import importlib
+
+ # Import the module and get the original function
+ module = importlib.import_module(model_module_path)
+ orig_rope = getattr(module, function_name)
+
+ # Set torch print options for better debugging
+ torch.set_printoptions(threshold=float('inf'))
+ torch.set_printoptions(precision=6, sci_mode=False)
+
+ def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ # log inputs
+ summarize(q, "RoPE.q_in")
+ summarize(k, "RoPE.k_in")
+
+ # call original
+ q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+ # log outputs
+ summarize(q_out, "RoPE.q_out")
+ summarize(k_out, "RoPE.k_out")
+
+ return q_out, k_out
+
+ # Patch it
+ setattr(module, function_name, debug_rope)
+ print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 2fb7f6374e..89d3249431 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false;
if (params.sampling.temp > 0) {
// stochastic verification
- common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
+ common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue;
}
- common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
+ common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index a65dcfbe1e..18d117f7cc 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -254,6 +254,7 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"gmml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
+set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 4c04c33003..262d78a4cf 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -386,6 +386,9 @@ if (GGML_CPU_ALL_VARIANTS)
ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD)
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
+ ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
+ ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
+ ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
elseif (APPLE)
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index fc31089f3e..28fb7612e5 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
+
if (GGML_XTHEADVECTOR)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
@@ -465,6 +466,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif()
+ if (GGML_RV_ZVFBFWMA)
+ string(APPEND MARCH_STR "_zvfbfwma")
+ endif()
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h
index 0775c87f98..3f8946ac70 100644
--- a/ggml/src/ggml-cpu/arch-fallback.h
+++ b/ggml/src/ggml-cpu/arch-fallback.h
@@ -43,6 +43,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -51,6 +53,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
@@ -67,10 +71,14 @@
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
// quants.c
@@ -91,6 +99,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -99,6 +109,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__loongarch64)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@@ -119,6 +131,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -127,6 +141,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__riscv)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@@ -154,6 +170,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
@@ -161,6 +179,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__s390x__)
// quants.c
#define quantize_row_q8_K_generic quantize_row_q8_K
@@ -187,6 +207,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -195,6 +217,8 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#elif defined(__wasm__)
// quants.c
#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -223,6 +247,8 @@
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
+#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
@@ -231,4 +257,6 @@
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
+#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
#endif
diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index fb7f074a85..b61220a189 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -786,6 +786,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
+void ggml_gemv_q8_0_4x4_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
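+    // vx holds the weights repacked as block_q8_0x4 (4 interleaved columns in groups of 4
+    // quants), vy a single row of q8_0 activations. Each vdotq_laneq_s32 below accumulates
+    // the dot product of one 4-quant activation lane against the matching slice of all
+    // four columns.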
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+ for (int c = 0; c < nc; c += ncols_interleaved) {
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ float32x4_t acc = vdupq_n_f32(0);
+ for (int b = 0; b < nb; b++) {
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+ int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+ int32x4_t ret = vdupq_n_s32(0);
+
+ ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+ ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+ ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+ ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+ ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+ ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+ ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+ ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+ a_ptr++;
+ b_ptr++;
+ }
+ vst1q_f32(s, acc);
+ s += ncols_interleaved;
+ }
+ return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+ for (int c = 0; c < nc; c += ncols_interleaved) {
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ float32x4_t acc = vdupq_n_f32(0);
+
+ for (int b = 0; b < nb; b++) {
+ int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+ int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+ float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+ int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
+ int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+ int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+ int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+ int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+ float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
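+            // With blocklen 8, each 16-byte weight vector packs 8 consecutive quants from two
+            // of the four interleaved columns, so the 8 activations are duplicated into both
+            // halves of a0..a3; vpaddq_s32 below folds the per-half sums into one int32 per column.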
+ int32x4_t ret0 = vdupq_n_s32(0);
+ int32x4_t ret1 = vdupq_n_s32(0);
+
+ // 0..7
+ ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+ ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+ // 8..15
+ ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+ ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+ // 16..23
+ ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+ ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+ // 24..31
+ ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+ ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+ int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+ acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+ a_ptr++;
+ b_ptr++;
+ }
+ vst1q_f32(s, acc);
+ s += ncols_interleaved;
+ }
+ return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -2610,3 +2737,159 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
+
+
+void ggml_gemm_q8_0_4x4_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+ float32x4_t sumf[4];
+ for (int m = 0; m < 4; m++) {
+ sumf[m] = vdupq_n_f32(0);
+ }
+
+ for (int l = 0; l < nb; l++) {
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
+
+ int32x4_t sumi_0 = vdupq_n_s32(0);
+ int32x4_t sumi_1 = vdupq_n_s32(0);
+ int32x4_t sumi_2 = vdupq_n_s32(0);
+ int32x4_t sumi_3 = vdupq_n_s32(0);
+
+ for (int k_group = 0; k_group < 8; k_group += 4) {
+ int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
+ int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
+
+ for (int k = 0; k < 4; k++) {
+ sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
+ sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
+ sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
+ sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
+ }
+ }
+
+ sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+ sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+ sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+ sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+ }
+
+ for (int m = 0; m < 4; m++) {
+ vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+ }
+ }
+ }
+ return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_4x8_q8_0(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+ for (int y = 0; y < nr; y += 4) {
+ const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+ for (int x = 0; x < nc; x += ncols_interleaved) {
+ const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+ const block_q8_0x4 * a_ptr = a_ptr_base;
+
+ float32x4_t acc_f32[4];
+ for (int i = 0; i < 4; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ int32x4_t acc[4];
+ for (int i = 0; i < 4; i++) {
+ acc[i] = vdupq_n_s32(0);
+ }
+
+ // Process 4 chunks of 8 positions each
+ for (int chunk = 0; chunk < 4; chunk++) {
+ int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+ int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+ int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+ int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
+ acc[0] = vmmlaq_s32(acc[0], a01, b01);
+ acc[1] = vmmlaq_s32(acc[1], a01, b23);
+ acc[2] = vmmlaq_s32(acc[2], a23, b01);
+ acc[3] = vmmlaq_s32(acc[3], a23, b23);
+ }
+
+ // Reorder outputs from 2×2 tiles to row-major
+ // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+ // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+ // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+ // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+ int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+ int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+ int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+ int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+ // Scales
+ float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+ float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+ acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+ acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+ acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+ acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+ a_ptr++;
+ b_ptr++;
+ }
+
+ for (int row = 0; row < 4; row++) {
+ vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
+ }
+ }
+ }
+ return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+ ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
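
The 2x2-tile reordering in ggml_gemm_q8_0_4x8_q8_0 above can be modeled in scalar code. The sketch below is a standalone C++ illustration with hypothetical names (not part of the patch); it only assumes the per-tile element order documented in the comments, i.e. acc[t] = [r0c0, r0c1, r1c0, r1c1] for that tile's 2x2 sub-block of the 4x4 output.

    // Scalar model of the 2x2-tile -> row-major reorder performed with
    // vget_low/vget_high + vcombine in the NEON kernel above.
    #include <cstdio>

    // acc[0] covers rows 0-1 / cols 0-1, acc[1] rows 0-1 / cols 2-3,
    // acc[2] rows 2-3 / cols 0-1, acc[3] rows 2-3 / cols 2-3.
    static void reorder_2x2_tiles(const int acc[4][4], int out[4][4]) {
        for (int t = 0; t < 4; t++) {
            const int row_base = (t / 2) * 2;
            const int col_base = (t % 2) * 2;
            for (int e = 0; e < 4; e++) {
                out[row_base + e / 2][col_base + e % 2] = acc[t][e];
            }
        }
    }

    int main() {
        int acc[4][4];
        for (int t = 0; t < 4; t++) {
            for (int e = 0; e < 4; e++) {
                acc[t][e] = 10 * t + e; // distinct values make the mapping visible
            }
        }
        int out[4][4];
        reorder_2x2_tiles(acc, out);
        for (int r = 0; r < 4; r++) {
            std::printf("%3d %3d %3d %3d\n", out[r][0], out[r][1], out[r][2], out[r][3]);
        }
        return 0;
    }
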
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index a59b518938..f7ba1fe317 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -3320,13 +3320,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec);
}
-#elif defined(__riscv_zvfh)
- for (int vl; i < n; i += vl) {
- vl = __riscv_vsetvl_e16m1(n - i);
- vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl);
- vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl);
- __riscv_vse32_v_f32m2(&y[i], vy, vl);
+
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin)
+ // calculate step size
+ const int epr = __riscv_vsetvlmax_e16m2();
+ const int step = epr * 2;
+ const int np = (n & ~(step - 1));
+
+ // unroll by 2
+ for (; i < np; i += step) {
+ vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr);
+ vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr);
+ __riscv_vse32_v_f32m4(y + i, ay0, epr);
+
+ vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr);
+ vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr);
+ __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
}
+
+ // leftovers
+ int vl;
+ for (i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m2(n - i);
+ vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl);
+ vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl);
+ __riscv_vse32_v_f32m4(y + i, ay0, vl);
+ }
+
#endif
for (; i < n; ++i) {
@@ -3371,6 +3391,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
(const __m128i *)(x + i))),
16)));
}
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin)
+ // calculate step size
+ const int epr = __riscv_vsetvlmax_e16m2();
+ const int step = epr * 2;
+ const int np = (n & ~(step - 1));
+
+ // unroll by 2
+ for (; i < np; i += step) {
+ vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr);
+ vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr);
+ __riscv_vse32_v_f32m4(y + i, ay0, epr);
+
+ vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr);
+ vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr);
+ __riscv_vse32_v_f32m4(y + i + epr, ay1, epr);
+ }
+
+ // leftovers
+ int vl;
+ for (i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m2(n - i);
+ vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl);
+ vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl);
+ __riscv_vse32_v_f32m4(y + i, ay0, vl);
+ }
#endif
for (; i < n; i++) {
y[i] = GGML_BF16_TO_FP32(x[i]);
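
The RVV conversion loops above all share one strip-mining shape: a main loop unrolled by two over the maximum vector length, then a tail loop at reduced length. The scalar C++ sketch below only illustrates that loop structure; CHUNK and convert_to_f32 are hypothetical stand-ins, and the real code handles the tail with shortened vector operations rather than scalars.

    // Scalar model of the unroll-by-2 + tail pattern used by the fp16/bf16 -> fp32
    // conversion paths above. CHUNK stands in for __riscv_vsetvlmax_e16m2().
    #include <cstdint>

    static const int CHUNK = 16; // hypothetical fixed vector length

    static void convert_to_f32(const uint16_t * x, float * y, int n,
                               float (*to_f32)(uint16_t)) {
        const int step = CHUNK * 2;
        const int np   = n & ~(step - 1); // largest multiple of step <= n

        int i = 0;
        for (; i < np; i += step) { // main loop: two independent chunks per iteration
            for (int k = 0; k < CHUNK; ++k) y[i + k]         = to_f32(x[i + k]);
            for (int k = 0; k < CHUNK; ++k) y[i + CHUNK + k] = to_f32(x[i + CHUNK + k]);
        }
        for (; i < n; ++i) { // leftovers
            y[i] = to_f32(x[i]);
        }
    }
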
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index b70ea7d78b..fbf7ed9432 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -692,6 +692,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
}
}
+void ggml_gemv_q8_0_4x4_q8_0_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert(nr == 1);
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(bs);
+ UNUSED(nr);
+
+ float sumf[4];
+ int sumi;
+
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / blocklen); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+ }
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[x * ncols_interleaved + j] = sumf[j];
+ }
+ }
+}
+
+void ggml_gemv_q8_0_4x8_q8_0_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert(nr == 1);
+ assert(n % qk == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(bs);
+ UNUSED(nr);
+
+ float sumf[4];
+ int sumi;
+
+ const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[j] = 0.0;
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / blocklen); k++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+ sumi += v0 * a_ptr[l].qs[k * blocklen + i];
+ }
+ sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+ }
+ }
+ }
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[x * ncols_interleaved + j] = sumf[j];
+ }
+ }
+}
+
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -1219,8 +1313,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
}
}
+void ggml_gemm_q8_0_4x4_q8_0_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ float sumf[4][4];
+ int sumi;
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[m][j] = 0.0;
+ }
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / blocklen); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+ }
+ sumf[m][j] +=
+ sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
+ }
+ }
+ }
+}
+
+void ggml_gemm_q8_0_4x8_q8_0_generic(int n,
+ float * GGML_RESTRICT s,
+ size_t bs,
+ const void * GGML_RESTRICT vx,
+ const void * GGML_RESTRICT vy,
+ int nr,
+ int nc) {
+ const int qk = QK8_0;
+ const int nb = n / qk;
+ const int ncols_interleaved = 4;
+ const int blocklen = 8;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ float sumf[4][4];
+ int sumi;
+
+ for (int y = 0; y < nr / 4; y++) {
+ const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumf[m][j] = 0.0;
+ }
+ }
+ for (int l = 0; l < nb; l++) {
+ for (int k = 0; k < (qk / blocklen); k++) {
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ sumi = 0;
+ for (int i = 0; i < blocklen; ++i) {
+ const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i];
+ sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i];
+ }
+ sumf[m][j] +=
+ sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+ }
+ }
+ }
+ }
+ for (int m = 0; m < 4; m++) {
+ for (int j = 0; j < ncols_interleaved; j++) {
+ s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+ }
+ }
+ }
+ }
+}
+
} // extern "C"
+static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) {
+ block_q8_0x4 out;
+
+ for (int i = 0; i < 4; i++) {
+ out.d[i] = in[i].d;
+ }
+
+ const int end = QK8_0 * 4 / blck_size_interleave;
+ for (int i = 0; i < end; ++i) {
+ int src_id = i % 4;
+ int src_offset = (i / 4) * blck_size_interleave;
+ int dst_offset = i * blck_size_interleave;
+ memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
+ }
+ return out;
+}
+
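
make_block_q8_0x4 interleaves the quantized bytes of four consecutive rows in chunks of blck_size_interleave, so consecutive chunks in memory come from rows 0..3 in turn. The sketch below is a standalone C++ illustration with a reduced row size (ROW_BYTES, a stand-in for QK8_0) and without the scale fields; it is not part of the patch.

    // Standalone model of the interleave pattern used by make_block_q8_0x4.
    #include <cstdio>
    #include <cstring>

    #define ROW_BYTES 16 // stand-in for QK8_0 (the real block holds 32 int8 values)

    static void interleave_rows(const signed char in[4][ROW_BYTES],
                                signed char out[4 * ROW_BYTES],
                                int blck_size_interleave) {
        const int end = ROW_BYTES * 4 / blck_size_interleave;
        for (int i = 0; i < end; ++i) {
            const int src_id     = i % 4;                          // cycle through the 4 rows
            const int src_offset = (i / 4) * blck_size_interleave; // advance within each row
            const int dst_offset = i * blck_size_interleave;
            std::memcpy(&out[dst_offset], &in[src_id][src_offset], blck_size_interleave);
        }
    }

    int main() {
        signed char in[4][ROW_BYTES];
        for (int r = 0; r < 4; r++) {
            for (int k = 0; k < ROW_BYTES; k++) {
                in[r][k] = (signed char) (r * ROW_BYTES + k);
            }
        }
        signed char out[4 * ROW_BYTES];
        // 4-byte chunks: r0[0..3], r1[0..3], r2[0..3], r3[0..3], r0[4..7], ...
        interleave_rows(in, out, 4);
        for (int i = 0; i < 4 * ROW_BYTES; i++) {
            std::printf("%d%c", out[i], (i % 16 == 15) ? '\n' : ' ');
        }
        return 0;
    }
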
static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
block_q4_0x4 out;
@@ -1534,6 +1749,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
+static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t,
+ int interleave_block,
+ const void * GGML_RESTRICT data,
+ size_t data_size) {
+ GGML_ASSERT(t->type == GGML_TYPE_Q8_0);
+ GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+ constexpr int nrows_interleaved = 4;
+
+ block_q8_0x4 * dst = (block_q8_0x4 *) t->data;
+ const block_q8_0 * src = (const block_q8_0 *) data;
+ block_q8_0 dst_tmp[4];
+ int nrow = ggml_nrows(t);
+ int nblocks = t->ne[0] / QK8_0;
+
+ GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0));
+
+ if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+ return -1;
+ }
+
+ for (int b = 0; b < nrow; b += nrows_interleaved) {
+ for (int64_t x = 0; x < nblocks; x++) {
+ for (int i = 0; i < nrows_interleaved; i++) {
+ dst_tmp[i] = src[x + i * nblocks];
+ }
+ *dst++ = make_block_q8_0x4(dst_tmp, interleave_block);
+ }
+ src += nrows_interleaved * nblocks;
+ }
+ return 0;
+}
+
static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
block_iq4_nlx4 out;
@@ -1702,6 +1949,14 @@ template <> int repack(struct ggml_tensor * t, const void *
return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}
+template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) {
+ return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size);
+}
+
// gemv
template
void gemv(int, float *, size_t, const void *, const void *, int, int);
@@ -1738,6 +1993,14 @@ template <> void gemv(int n, float * s, size
ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
// gemm
template
void gemm(int, float *, size_t, const void *, const void *, int, int);
@@ -1774,6 +2037,14 @@ template <> void gemm(int n, float * s, size
ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+ ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
class tensor_traits_base : public ggml::cpu::tensor_traits {
public:
virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
@@ -2168,6 +2439,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0;
+ // instances for Q8_0

+ static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0;
+ static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0;
+
if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
|| (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
@@ -2218,6 +2493,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &iq4_nl_4x4_q8_0;
}
}
+ } else if (cur->type == GGML_TYPE_Q8_0) {
+ if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+ if (cur->ne[1] % 4 == 0) {
+ return &q8_0_4x8_q8_0;
+ }
+ }
+ if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+ if (cur->ne[1] % 4 == 0) {
+ return &q8_0_4x4_q8_0;
+ }
+ }
}
return nullptr;
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index c4d928cd15..af98e70344 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -98,6 +98,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
@@ -120,6 +124,10 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
#if defined(__cplusplus)
} // extern "C"
diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp
index ac8633e212..427e63245b 100644
--- a/ggml/src/ggml-cpu/vec.cpp
+++ b/ggml/src/ggml-cpu/vec.cpp
@@ -195,8 +195,48 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
sumf += (ggml_float)_mm_cvtss_f32(g);
#undef LOAD
-#endif
+#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma)
+ size_t vl = __riscv_vsetvlmax_e32m4();
+ // initialize accumulators to all zeroes
+ vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+ vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+ // calculate step size
+ const size_t epr = __riscv_vsetvlmax_e16m2();
+ const size_t step = epr * 2;
+ const int np = (n & ~(step - 1));
+
+ // unroll by 2
+ for (; i < np; i += step) {
+ vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr);
+ vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr);
+ vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr);
+ __asm__ __volatile__ ("" ::: "memory");
+
+ vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr);
+ vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr);
+ vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr);
+ __asm__ __volatile__ ("" ::: "memory");
+ }
+
+ // accumulate in 1 register
+ vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl);
+
+ // leftovers
+ for (i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m2(n - i);
+ vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl);
+ vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl);
+ vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl);
+ }
+
+ // reduce
+ vl = __riscv_vsetvlmax_e32m4();
+ vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+ sumf += __riscv_vfmv_f_s_f32m1_f32(redsum);
+
+#endif
for (; i < n; ++i) {
sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) *
GGML_BF16_TO_FP32(y[i]));
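
The bf16 dot product above keeps two independent accumulators (vsum0/vsum1) in the unrolled loop and combines them only once before the final reduction, which hides the latency of the widening FMA. Below is a scalar C++ sketch of that accumulator structure, with a hypothetical fixed chunk size and plain floats in place of the vector registers.

    // Scalar model of the dual-accumulator + final-reduce pattern used above.
    static float dot_dual_acc(const float * x, const float * y, int n) {
        const int chunk = 4;                // stand-in for one vector length
        const int step  = chunk * 2;
        const int np    = n & ~(step - 1);

        float sum0 = 0.0f, sum1 = 0.0f;     // independent partial sums (vsum0 / vsum1)
        int i = 0;
        for (; i < np; i += step) {
            for (int k = 0; k < chunk; ++k) sum0 += x[i + k]         * y[i + k];
            for (int k = 0; k < chunk; ++k) sum1 += x[i + chunk + k] * y[i + chunk + k];
        }
        float sum = sum0 + sum1;            // combine once, like vfadd before vfredusum
        for (; i < n; ++i) {                // leftovers
            sum += x[i] * y[i];
        }
        return sum;
    }
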
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h
index bd80805fdc..3198b33b50 100644
--- a/ggml/src/ggml-cpu/vec.h
+++ b/ggml/src/ggml-cpu/vec.h
@@ -224,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
- #elif defined(__riscv_v_intrinsic)
- // todo: RVV impl
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
- sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
- }
- }
+
+ #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
+ size_t vl = __riscv_vsetvlmax_e32m4();
+
+ // initialize accumulators to all zeroes
+ vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+ vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+ vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+ vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl);
+
+ // calculate step size
+ const size_t epr = __riscv_vsetvlmax_e16m2();
+ const size_t step = epr * 2;
+ const int np = (n & ~(step - 1));
+
+ // unroll by 2 along the row dimension
+ for (int i = 0; i < np; i += step) {
+ vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr);
+ vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr);
+ vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr);
+ vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr);
+ vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr);
+
+ vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr);
+ vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr);
+ vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr);
+ vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr);
+ vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr);
+ }
+
+ vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl);
+ vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl);
+
+ // leftovers
+ for (int i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m2(n - i);
+ vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl);
+ vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl);
+ vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl);
+
+ vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl);
+ vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl);
+ }
+
+ // reduce
+ vl = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0),
+ __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl);
+ vl = __riscv_vsetvlmax_e32m1();
+ vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0),
+ __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl);
+ vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
+ acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+
+ vl = __riscv_vsetvlmax_e32m2();
+ vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0),
+ __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl);
+ vl = __riscv_vsetvlmax_e32m1();
+ vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0),
+ __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl);
+ vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
+ acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
+ sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
+ sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
+
#else
const int np = (n & ~(GGML_F16_STEP - 1));
@@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
}
np = n;
#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic
- const int np = n;
- _Float16 hv = (_Float16)v;
- for (int i = 0, avl; i < n; i += avl) {
- avl = __riscv_vsetvl_e16m8(n - i);
- vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl);
- vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl);
- vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl);
- __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl);
+ const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
+ const _Float16 scale = *(const _Float16*)(&s);
+
+ // calculate step size
+ const int epr = __riscv_vsetvlmax_e16m4();
+ const int step = epr * 2;
+ int np = (n & ~(step - 1));
+
+ // unroll by 2
+ for (int i = 0; i < np; i += step) {
+ vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr);
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+ ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr);
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+ __asm__ __volatile__ ("" ::: "memory");
+
+ vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr);
+ vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+ ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr);
+ __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+ __asm__ __volatile__ ("" ::: "memory");
}
+
+ // leftovers
+ int vl;
+ for (int i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m4(n - i);
+ vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl);
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+ ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl);
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
+ }
+ np = n;
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
@@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
- for (int i = 0, vl; i < n; i += vl) {
- vl = __riscv_vsetvl_e16m2(n - i);
- vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl);
- vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl);
- vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl);
- vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl);
- __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl);
+ const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v);
+ const _Float16 scale = *(const _Float16*)(&s);
+
+ // calculate step size
+ const int epr = __riscv_vsetvlmax_e16m4();
+ const int step = epr * 2;
+ const int np = (n & ~(step - 1));
+
+ // unroll by 2
+ for (int i = 0; i < np; i += step) {
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr);
+ ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr);
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr);
+ __asm__ __volatile__ ("" ::: "memory");
+
+ vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr);
+ ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr);
+ __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr);
+ __asm__ __volatile__ ("" ::: "memory");
+ }
+
+ // leftovers
+ int vl;
+ for (int i = np; i < n; i += vl) {
+ vl = __riscv_vsetvl_e16m4(n - i);
+ vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl);
+ ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl);
+ __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl);
}
#elif defined(GGML_SIMD)
const int np = (n & ~(GGML_F16_STEP - 1));
diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu
index 5340eedc08..51967c667c 100644
--- a/ggml/src/ggml-cuda/argmax.cu
+++ b/ggml/src/ggml-cuda/argmax.cu
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
}
#pragma unroll
- for (int offset = 16; offset > 0; offset >>= 1) {
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
argmax = shared_argmax[lane_id];
}
#pragma unroll
- for (int offset = 16; offset > 0; offset >>= 1) {
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
if (val > maxval) {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ab0f6fe9ce..55fa2e6a7c 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3076,8 +3076,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@@ -3085,7 +3088,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
+
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
@@ -3094,8 +3101,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+ ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
+ ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
+ int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
- if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+ if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu
index 347abc1866..691d8dcb14 100644
--- a/ggml/src/ggml-cuda/mean.cu
+++ b/ggml/src/ggml-cuda/mean.cu
@@ -63,6 +63,9 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
+
+ // Heuristic for block size selection to optimize occupancy.
+ // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
if ((nrows / nsm) < 2) {
const dim3 block_dims(512, 1, 1);
reduce_rows_f32*norm=*/true><<>>(src0_d, dst_d, ncols);
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index dcfa40f4d5..3268dadfe8 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -76,15 +76,29 @@ namespace ggml_cuda_mma {
// For the A/C matrices this means I major == row major, J major == column major.
// For the B matrix this means I major == column major, J major == row major.
// MIRRORED == Each data value is held exactly once per thread subgroup.
- DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
- DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
- DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+ DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
};
// Implemented mma combinations are:
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+ static constexpr bool is_i_major(const data_layout dl) {
+ return dl == DATA_LAYOUT_I_MAJOR ||
+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
+ }
+
+ static constexpr __device__ data_layout get_input_data_layout() {
+#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ return DATA_LAYOUT_I_MAJOR_MIRRORED;
+#else
+ return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
template
struct tile {};
@@ -115,9 +129,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
- return 4 * (threadIdx.x / 16) + l;
+ return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
- return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+ return threadIdx.x % 32;
} else {
NO_DEVICE_CODE;
return -1;
@@ -132,9 +146,9 @@ namespace ggml_cuda_mma {
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
- return threadIdx.x % 16;
+ return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
- return threadIdx.x % 32;
+ return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
NO_DEVICE_CODE;
return -1;
@@ -171,28 +185,19 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
- static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
T x[ne] = {0};
static constexpr __device__ bool supported() {
if (I == 16 && J == 16) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 4) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
- if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
- return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
- return 2 * l + (threadIdx.x / 16);
-#else
- NO_DEVICE_CODE;
- return -1;
-#endif // defined(RDNA4)
+ if constexpr (supported()) {
+ return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
@@ -201,7 +206,17 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
- return threadIdx.x % 16;
+ // matrix C
+#if defined(RDNA3)
+ return 2 * l + (threadIdx.x / 16);
+#else
+ return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+ } else if constexpr (I == 16 && J == 8) {
+ // mmq input for RDNA4
+ return ne * (threadIdx.x / 16) + l;
+ } else if constexpr (I == 16 && J == 4) {
+ return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
@@ -293,12 +308,7 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
- // RDNA3 has duplicated data as input.
- static constexpr int ne = I * J / 32 * 2;
-#else
static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
@@ -317,14 +327,7 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
- return l;
-#else
- NO_DEVICE_CODE;
- return -1;
-#endif // defined(RDNA4)
} else {
NO_DEVICE_CODE;
return -1;
@@ -382,42 +385,19 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
- // RDNA3 has duplicated data as input.
- static constexpr int ne = I * J / 32 * 2;
-#else
static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
- if (I == 16 && J == 8) return true;
- return false;
+ return tile::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
- if constexpr (I == 16 && J == 8) {
- return threadIdx.x % 16;
- } else {
- NO_DEVICE_CODE;
- return -1;
- }
+ return tile::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
- if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
- return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
- return l;
-#else
- NO_DEVICE_CODE;
- return -1;
-#endif // defined(RDNA4)
- } else {
- NO_DEVICE_CODE;
- return -1;
- }
+ return tile::get_j(l);
}
#else
static constexpr int ne = I * J / WARP_SIZE;
@@ -458,11 +438,87 @@ namespace ggml_cuda_mma {
#endif // defined(AMD_WMMA_AVAILABLE)
};
+ template
+ struct tile {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+ static constexpr int ne = tile::ne;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ return tile::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile::get_j(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile::get_i(l);
+ }
+ };
+
+ template
+ struct tile {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+
+ // RDNA3
+ static constexpr int ne = I * J / 32 * 2;
+
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 16 && J == 16) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
+ if constexpr (supported()) {
+ return threadIdx.x % 16;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (supported()) {
+ return l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+ };
+
template
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+#if defined(RDNA3)
+ static constexpr int ne = tile::ne;
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile::get_j(l);
+ }
+#else // Volta
static constexpr int ne = I * J / (WARP_SIZE/4);
half2 x[ne] = {{0.0f, 0.0f}};
@@ -489,6 +545,29 @@ namespace ggml_cuda_mma {
return -1;
}
}
+#endif // defined(RDNA3)
+ };
+
+ template
+ struct tile {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+ static constexpr int ne = tile::ne;
+
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile::get_j(l);
+ }
};
template
@@ -569,55 +648,28 @@ namespace ggml_cuda_mma {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
- int64_t * xi = (int64_t *) t.x;
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
- xi[0] = xs[0];
+ ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
#elif defined(AMD_WMMA_AVAILABLE)
- if constexpr (std::is_same_v || std::is_same_v) {
-#if defined(RDNA4)
- ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
- ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
- ggml_cuda_memcpy_1(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
- NO_DEVICE_CODE;
-#endif // defined(RDNA4)
- } else if constexpr (std::is_same_v) {
- if constexpr (I == 16 && J == 4) {
- int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
- xi[0] = xs[0];
-#elif defined(RDNA3)
- static_assert(tile::ne >= 4, "fragment too small");
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
- xi[0] = xs[0];
- xi[1] = xs[1];
-#endif // defined(RDNA4)
- } else if constexpr (I == 16 && J == 8) {
- int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
- xi[0] = xs[0];
-
- const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
- xi[1] = xs1[0];
-#elif defined(RDNA3)
- static_assert(tile::ne >= 8, "fragment too small");
- const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
- // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
- xi[0] = xs[0];
- xi[1] = xs[1];
- const int64_t * xs1 = xs + 2;
- xi[2] = xs1[0];
- xi[3] = xs1[1];
-#endif // defined(RDNA4)
+ // All wmma layouts have contiguous data when i-major.
+ if constexpr (is_i_major(dl)) {
+ // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+ constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+ if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+ static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+ constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+ for (int i = 0; i < aligned_copy_count; ++i) {
+ ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+ }
} else {
- NO_DEVICE_CODE;
+ ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
}
} else {
- NO_DEVICE_CODE;
+#pragma unroll
+ for (int l = 0; l < t.ne; ++l) {
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+ }
}
#else
#pragma unroll
@@ -660,9 +712,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
- template
+ template
static __device__ __forceinline__ void load_ldmatrix(
- tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+ tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
#if defined(TURING_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +884,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
+ template
static __device__ __forceinline__ void mma(
- tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+ tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
#ifdef AMPERE_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@@ -887,8 +940,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
+ template
static __device__ __forceinline__ void mma(
- tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
#ifdef TURING_MMA_AVAILABLE
const int * Axi = (const int *) A.x;
const int * Bxi = (const int *) B.x;
@@ -940,8 +994,9 @@ namespace ggml_cuda_mma {
#endif // TURING_MMA_AVAILABLE
}
+ template
static __device__ __forceinline__ void mma(
- tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +1022,9 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
+ template
static __device__ __forceinline__ void mma(
- tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+ tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1178,9 @@ namespace ggml_cuda_mma {
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
}
-static __device__ __forceinline__ void mma(
- tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+ template
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
#if defined(AMD_WMMA_AVAILABLE)
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index e1c695c5c0..e36730948f 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -32,11 +32,13 @@ static __global__ void mul_mat_f(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
- constexpr int tile_B_I = std::is_same_v ? 8 : 16;
- constexpr int tile_C_J = std::is_same_v ? 8 : 16;
- typedef tile<16, 8, T> tile_A;
- typedef tile tile_B;
- typedef tile<16, tile_C_J, float> tile_C;
+ constexpr bool is_tf32 = std::is_same_v;
+ constexpr int tile_B_I = is_tf32 ? 8 : 16;
+ constexpr int tile_C_J = is_tf32 ? 8 : 16;
+ constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+ typedef tile<16, 8, T, ab_layout> tile_A;
+ typedef tile tile_B;
+ typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else {
@@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids(
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
- constexpr int tile_B_I = std::is_same_v ? 8 : 16;
- constexpr int tile_C_J = std::is_same_v ? 8 : 16;
- typedef tile<16, 8, T> tile_A;
- typedef tile tile_B;
- typedef tile<16, tile_C_J, float> tile_C;
+ constexpr bool is_tf32 = std::is_same_v;
+ constexpr int tile_B_I = is_tf32 ? 8 : 16;
+ constexpr int tile_C_J = is_tf32 ? 8 : 16;
+ constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+ typedef tile<16, 8, T, ab_layout> tile_A;
+ typedef tile tile_B;
+ typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else {
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 1298f99fff..fa8a72c9c1 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -797,9 +797,10 @@ template
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
- typedef tile<16, 8, int> tile_A;
- typedef tile<16, 8, int> tile_B;
- typedef tile<16, 16, int> tile_C;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -966,9 +967,10 @@ template
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
- typedef tile<16, 8, int> tile_A;
- typedef tile<16, 8, int> tile_B;
- typedef tile<16, 16, int> tile_C;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -1130,10 +1132,11 @@ template
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
- typedef tile<16, 8, int> tile_A;
- typedef tile<16, 8, int> tile_B;
- typedef tile<16, 16, int> tile_C;
- typedef tile<64, 2, int> tile_load;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -1179,9 +1182,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
- typedef tile<16, 4, int> tile_A;
- typedef tile<16, 4, int> tile_B;
- typedef tile<16, 16, int> tile_C;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -1435,10 +1439,11 @@ template
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
- typedef tile<16, 8, int> tile_A;
- typedef tile<16, 8, int> tile_B;
- typedef tile<16, 16, int> tile_C;
- typedef tile<64, 2, int> tile_load;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -1501,10 +1506,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-
- typedef tile<16, 4, int> tile_A;
- typedef tile<16, 4, int> tile_B;
- typedef tile<16, 16, int> tile_C;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -2265,10 +2270,11 @@ template
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#if defined(AMD_MFMA_AVAILABLE)
- typedef tile<16, 8, int> tile_A;
- typedef tile<16, 8, int> tile_B;
- typedef tile<16, 16, int> tile_C;
- typedef tile<64, 2, int> tile_load;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -2316,9 +2322,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
- typedef tile<16, 4, int> tile_A;
- typedef tile<16, 4, int> tile_B;
- typedef tile<16, 16, int> tile_C;
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = granularity;
@@ -3015,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
constexpr int tileC_IJ = mmq_get_granularity_device(0);
- typedef tile tile_C;
+ typedef tile tile_C;
constexpr int rows_per_warp = granularity;
#else
typedef tile<16, 8, int> tile_C;
diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu
index 4197973360..6d5ea704c6 100644
--- a/ggml/src/ggml-cuda/ssm-conv.cu
+++ b/ggml/src/ggml-cuda/ssm-conv.cu
@@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
const int threads = 128;
GGML_ASSERT(nr % threads == 0);
- if (n_t <= 32) {
- const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
- if (nc == 4) {
- ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
- dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else if (nc == 3) {
- ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
- dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ auto launch_kernel = [&](auto NC) {
+ constexpr int kNC = decltype(NC)::value;
+ if (n_t <= 32) {
+ const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+ ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
- GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
- }
- } else {
- if (nc == 4) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
- ssm_conv_long_token_f32<<>>(
+ ssm_conv_long_token_f32<<>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else if (nc == 3) {
- const int64_t split_n_t = 32;
- dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
- ssm_conv_long_token_f32<<>>(
- src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
- } else {
- GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
+ };
+
+ switch (nc) {
+ case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
+ case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+ case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
+ default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
}
}
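
launch_kernel above turns the runtime kernel size nc into a compile-time constant by passing std::integral_constant through a generic lambda. The standalone C++ sketch below shows the same dispatch idiom with hypothetical names (run_conv stands in for the templated CUDA kernel launches).

    #include <cstdio>
    #include <type_traits>

    template <int KernelSize>
    static void run_conv() {                          // stand-in for the templated kernel launch
        std::printf("instantiated for kernel size %d\n", KernelSize);
    }

    static bool dispatch_conv(int nc) {
        auto launch = [](auto NC) {
            constexpr int kNC = decltype(NC)::value;  // compile-time copy of the runtime case
            run_conv<kNC>();
        };
        switch (nc) {
            case 3: launch(std::integral_constant<int, 3>{}); return true;
            case 4: launch(std::integral_constant<int, 4>{}); return true;
            case 9: launch(std::integral_constant<int, 9>{}); return true;
            default: return false;                    // unsupported kernel size
        }
    }

    int main() {
        return dispatch_conv(3) && dispatch_conv(9) ? 0 : 1;
    }
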
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 572379fcbf..48e569efa0 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
}
}
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+ const ggml_tensor * weights,
+ const ggml_tensor * get_rows,
+ const ggml_tensor * argsort,
+ const ggml_tensor * clamp,
+ int n_expert) {
+ ggml_tensor * probs = get_rows->src[0];
+ if (probs->op != GGML_OP_RESHAPE) {
+ return false;
+ }
+ probs = probs->src[0];
+ ggml_tensor * selection_probs = argsort->src[0];
+
+ if (probs != selection_probs) {
+ return false;
+ }
+
float scale = 1.0f;
float max_bias = 0.0f;
@@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
return false;
}
- const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
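
Besides the new producer-chain checks, the fused path still requires n_expert (now passed in explicitly) to be a power of two no larger than 512. A small illustration of the bit trick used for that test (plain C++; the sample values are arbitrary):

#include <cstdio>

// A positive n is a power of two iff clearing its lowest set bit yields zero.
static bool is_pow2(int n) {
    return n > 0 && (n & (n - 1)) == 0;
}

int main() {
    for (int n : {1, 8, 96, 128, 512, 513}) {
        std::printf("n_expert=%3d -> %s\n", n, (is_pow2(n) && n <= 512) ? "eligible" : "rejected");
    }
}
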
diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh
index 2eff408b03..6b6c13c587 100644
--- a/ggml/src/ggml-cuda/topk-moe.cuh
+++ b/ggml/src/ggml-cuda/topk-moe.cuh
@@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const bool delayed_softmax = false,
ggml_tensor * weight_clamp = nullptr);
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax,
+ const ggml_tensor * weights,
+ const ggml_tensor * get_rows,
+ const ggml_tensor * argsort,
+ const ggml_tensor * clamp,
+ int n_expert);
std::initializer_list<ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt
index ac422027b9..d58e287823 100644
--- a/ggml/src/ggml-hexagon/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/CMakeLists.txt
@@ -2,6 +2,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
+set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@@ -41,7 +42,8 @@ set(HTP_CMAKE_ARGS
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
- -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
+ -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
+ -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
ExternalProject_Add(htp-v68
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 72a82a8911..6a00abacc3 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -1976,9 +1976,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
break;
case GGML_TYPE_F16:
- if (!opt_experimental) {
- return false;
- }
break;
default:
@@ -2164,8 +2161,14 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
}
// src0, src1 & dst must be mapped to the same session
- if (!hex_supported_buffer(sess, src0, src1, dst)) {
- return false;
+    if (src1) {
+ if (!hex_supported_buffer(sess, src0, src1, dst)) {
+ return false;
+ }
+    } else {
+ if (!hex_supported_buffer(sess, src0, dst)) {
+ return false;
+ }
}
return true;
@@ -2665,6 +2668,10 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
req.op = HTP_OP_UNARY_SILU;
supported = true;
}
+ else if (ggml_get_unary_op(dst) == GGML_UNARY_OP_GELU){
+ req.op = HTP_OP_UNARY_GELU;
+ supported = true;
+ }
break;
case GGML_OP_GLU:
@@ -2680,6 +2687,7 @@ static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) {
case GGML_OP_SOFT_MAX:
req.op = HTP_OP_SOFTMAX;
supported = true;
+ break;
default:
break;
@@ -2959,6 +2967,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_UNARY:
if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) {
ggml_hexagon_unary(node, flags);
+ } else if (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU) {
+ ggml_hexagon_unary(node, flags);
}
break;
case GGML_OP_GLU:
@@ -3257,7 +3267,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
auto sess = static_cast<ggml_hexagon_session *>(dev->context);
bool supp = false;
-
switch (op->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
@@ -3297,10 +3306,13 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) {
supp = ggml_hexagon_supported_activations(sess, op);
}
+ else if (ggml_get_unary_op(op) == GGML_UNARY_OP_GELU){
+ supp = ggml_hexagon_supported_activations(sess, op);
+ }
break;
case GGML_OP_GLU:
- if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) {
+        if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI)) {
supp = ggml_hexagon_supported_activations(sess, op);
}
break;
diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
index 22e3fea11d..2cf8aaa42a 100644
--- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt
+++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt
@@ -31,7 +31,8 @@ add_library(${HTP_LIB} SHARED
)
target_compile_definitions(${HTP_LIB} PRIVATE
-    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
+    $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
+ FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
build_idl(htp_iface.idl ${HTP_LIB})
diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c
index 87b09cca3a..586b5c1f92 100644
--- a/ggml/src/ggml-hexagon/htp/act-ops.c
+++ b/ggml/src/ggml-hexagon/htp/act-ops.c
@@ -231,7 +231,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
// x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
- hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
+ hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
// y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
// x1 (dst_spad_data) = alpha * (x)
@@ -255,6 +255,91 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
+
+static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
+ struct htp_tensor * dst,
+ const int32_t * op_params,
+ struct htp_spad * src0_spad,
+ struct htp_spad * dst_spad,
+ uint32_t nth,
+ uint32_t ith,
+ uint32_t src0_nrows_per_thread) {
+ htp_act_preamble2;
+
+ uint64_t t1, t2;
+ t1 = HAP_perf_get_qtimer_count();
+
+ const size_t src0_row_size = nb01;
+ const size_t dst_row_size = nb1;
+
+ const uint32_t src0_nrows = ne01 * ne02 * ne03;
+
+ const uint32_t src0_start_row = src0_nrows_per_thread * ith;
+ const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
+
+ // no work for this thread
+ if (src0_start_row >= src0_end_row) {
+ return;
+ }
+
+ int is_aligned = 1;
+ int opt_path = 0;
+ if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
+ is_aligned = 0;
+        FARF(HIGH, "gelu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
+ }
+ if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
+ opt_path = 1;
+ }
+
+ const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
+ uint8_t * restrict data_dst = (uint8_t *) dst->data;
+
+ uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
+ uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
+
+ const int BLOCK = 8;
+ for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
+ const uint32_t block_end = MIN(ir + BLOCK, src0_end_row);
+
+ // Prefetch next block
+ if (block_end < src0_end_row) {
+ const float * restrict prefetch_ptr = (float *) (data_src0 + (block_end * src0_row_size));
+ htp_l2fetch(prefetch_ptr, 1, block_end * src0_row_size, src0_row_size);
+ }
+
+ // Process rows in current block
+ for (uint32_t ib = ir; ib < block_end; ib++) {
+ const float * restrict src0 = (float *) (data_src0 + (ib * src0_row_size));
+ float * restrict dst = (float *) (data_dst + (ib * dst_row_size));
+
+ // gelu = x * sigmoid(1.702 * x) // current implementation
+ if (1 == opt_path) {
+ hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
+ hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
+ hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+ } else {
+                hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
+ hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
+ hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
+ }
+ }
+ }
+
+ t2 = HAP_perf_get_qtimer_count();
+
+ FARF(HIGH, "gelu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
+ ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
+}
+
+static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) {
+ struct htp_ops_context * octx = (struct htp_ops_context *) data;
+ unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
+ octx->src0_nrows_per_thread);
+}
+
+
+
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_tensor * dst,
const int32_t * op_params,
@@ -371,7 +456,10 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) {
act_op_func = glu_swiglu_oai_fp32;
op_type = "swiglu-oai-f32";
break;
-
+ case HTP_OP_UNARY_GELU:
+ act_op_func = unary_gelu_fp32;
+ op_type = "gelu-f32";
+ break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;
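
The new HTP path implements GELU with the x * sigmoid(1.702 * x) approximation noted in the comment above. A scalar reference comparing it against the erf-based definition (illustration only, not the HVX code):

#include <cmath>
#include <cstdio>

static float gelu_erf(float x)   { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }
static float gelu_quick(float x) { return x / (1.0f + std::exp(-1.702f * x)); } // x * sigmoid(1.702 * x)

int main() {
    for (float x : {-3.0f, -1.0f, 0.0f, 0.5f, 2.0f}) {
        std::printf("x=% .1f  erf-gelu=% .5f  sigmoid-gelu=% .5f\n", x, gelu_erf(x), gelu_quick(x));
    }
}
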
diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h
index 9278f41f4e..a61652304a 100644
--- a/ggml/src/ggml-hexagon/htp/htp-msg.h
+++ b/ggml/src/ggml-hexagon/htp/htp-msg.h
@@ -51,11 +51,12 @@ enum htp_op {
HTP_OP_MUL_MAT_ID = 5,
HTP_OP_RMS_NORM = 6,
HTP_OP_UNARY_SILU = 7,
- HTP_OP_GLU_SWIGLU = 8,
- HTP_OP_GLU_SWIGLU_OAI = 9,
- HTP_OP_SOFTMAX = 10,
- HTP_OP_ADD_ID = 11,
- HTP_OP_ROPE = 12,
+ HTP_OP_UNARY_GELU = 8,
+ HTP_OP_GLU_SWIGLU = 9,
+ HTP_OP_GLU_SWIGLU_OAI = 10,
+ HTP_OP_SOFTMAX = 11,
+ HTP_OP_ADD_ID = 12,
+ HTP_OP_ROPE = 13,
INVALID
};
diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c
index e02b1d9099..f9e02ab67e 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.c
+++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c
@@ -49,6 +49,8 @@ void hvx_mul_f32(const uint8_t * restrict src0,
FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
}
+
+ bool handled_leftover = false;
if (0 == unaligned_loop) {
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
@@ -60,18 +62,59 @@ void hvx_mul_f32(const uint8_t * restrict src0,
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
}
} else {
+        int step_of_1 = num_elems_whole >> 5; // divide by 32: 32 floats = 128 bytes = one HVX vector
+ int leftover_size = left_over * sizeof(float);
+
+
+ HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
+ HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
+ HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
+
+ HVX_Vector slinep;
+ HVX_Vector slinec;
+ HVX_Vector sline;
+ HVX_Vector sline2p;
+ HVX_Vector sline2c;
+ HVX_Vector sline2;
+
+ slinep = *vec_in1++;
+ sline2p = *vec_in2++;
#pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
- HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
+ for (int i = step_of_1 - 1; i > 0; i--) {
+ slinec = *vec_in1++;
+ sline2c = *vec_in2++;
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+ sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
- HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
+ *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
+ slinep = slinec;
+ sline2p = sline2c;
+ }
+        if (step_of_1 > 0) {
+ slinec = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++;
+ sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++;
- *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+ sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
+ *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2));
+ slinep = slinec;
+ sline2p = sline2c;
+ }
+ if (left_over > 0) {
+ slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++);
+
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0);
+ sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++);
+ sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1);
+
+ HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2);
+ hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out));
+ handled_leftover = true;
}
}
- if (left_over > 0) {
+
+ if (left_over > 0 && !handled_leftover) {
const float * src0f = (const float *) src0 + num_elems_whole;
const float * src1f = (const float *) src1 + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
@@ -464,7 +507,7 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
}
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
-
+ bool handled_leftover = false;
if (0 == unaligned_loop) {
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
@@ -475,17 +518,47 @@ void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
}
} else {
+        int step_of_1 = num_elems >> 5; // divide by 32: 32 floats = 128 bytes = one HVX vector
+ int leftover_size = left_over * sizeof(float);
+
+ HVX_Vector * input_v_ptr = (HVX_Vector *) src;
+ HVX_UVector * output_v_ptr = (HVX_UVector *) dst;
+
+ HVX_Vector slinep;
+ HVX_Vector slinec;
+ HVX_Vector sline;
+
+ slinep = *input_v_ptr++;
+
#pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
+ for (int i = step_of_1 - 1; i > 0; i--) {
+ slinec = *input_v_ptr++;
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+ *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
+ /* Prepare slinep for next iteration */
+ slinep = slinec;
+ }
- HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
+ if (step_of_1 > 0) {
+ slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++;
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+ *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
- *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
+ slinep = slinec;
+ }
+
+ if (leftover_size > 0) {
+ slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++);
+
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src);
+
+ HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec));
+ hvx_vec_store_u(output_v_ptr, leftover_size, sout);
+ handled_leftover = true;
}
}
- if (left_over > 0) {
+ if (left_over > 0 && !handled_leftover) {
const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
@@ -875,35 +948,45 @@ float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
size_t left_over = num_elems & (VLEN_FP32 - 1);
size_t num_elems_whole = num_elems - left_over;
-
+ int unalign_address = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+ unalign_address = 1;
}
- assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
-
const float * src_f = (const float *) src;
- HVX_Vector vec_min = Q6_V_vsplat_R(val);
+ HVX_Vector vec_min = hvx_vec_splat_fp32(val);
- HVX_Vector * restrict vec_in = (HVX_Vector *) src;
- HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
+    if (unalign_address == 0) {
+ HVX_Vector * restrict vec_in = (HVX_Vector *) src;
+ HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
- #pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
- *vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
+ *vec_out++ = (min_clamp);
+ }
+    } else {
+        HVX_UVector * restrict vec_in  = (HVX_UVector *) src;
+        HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
+
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
+ *vec_out++ = (min_clamp);
+ }
}
- if (left_over > 0) {
+    if (left_over > 0) {
const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
- HVX_Vector in = *(HVX_UVector *) srcf;
+ HVX_UVector in = *(HVX_UVector *) srcf;
- vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
+ HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
- hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
+ hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
}
}
@@ -915,46 +998,70 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
size_t left_over = num_elems & (VLEN_FP32 - 1);
size_t num_elems_whole = num_elems - left_over;
+ int unalign_address = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
+ unalign_address = 1;
}
- assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
-
- HVX_Vector * restrict vec_in = (HVX_Vector *) src;
- HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
-
HVX_Vector range_left = hvx_vec_splat_fp32(limit_left);
HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
- #pragma unroll(4)
- for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
- HVX_Vector in_vec = *vec_in++;
- HVX_Vector temp_v = in_vec;
+    if (unalign_address == 0) {
+ HVX_Vector * restrict vec_in = (HVX_Vector *) src;
+ HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
- HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
- HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
- in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
- in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
- *vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ HVX_Vector in_vec = *vec_in++;
+ HVX_Vector temp_v = in_vec;
+
+ HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+ HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
+
+ in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+ in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
+
+ *vec_out++ = in_vec;
+ }
+
+    } else {
+
+ HVX_UVector * restrict vec_in = (HVX_UVector *) src;
+ HVX_UVector * restrict vec_out = (HVX_UVector *) dst;
+
+ #pragma unroll(4)
+ for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
+ HVX_Vector in_vec = *vec_in++;
+ HVX_Vector temp_v = in_vec;
+
+ HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+ HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
+
+ in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+ in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
+
+ *vec_out++ = in_vec;
+ }
+
}
if (left_over > 0) {
const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
- HVX_Vector in = *(HVX_UVector *) srcf;
+ HVX_Vector in_vec = *(HVX_UVector *) srcf;
- HVX_Vector temp_v = in;
+ HVX_Vector temp_v = in_vec;
- HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
- HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
+ HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
+ HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
- in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
- in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
+ in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
+ in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec);
- hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
+ hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
}
}
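
The unaligned branches above all follow the same pattern: keep the previous and current aligned 128-byte loads ("slinep"/"slinec") and splice the unaligned window out of them with Q6_V_valign_VVR. A rough scalar model of that idea, using 16-byte chunks for brevity (a sketch assuming valign splices by the low address bits; not the exact intrinsic semantics):

#include <cstdint>
#include <cstring>

constexpr size_t CHUNK = 16;   // stands in for the 128-byte HVX vector length

// Read CHUNK bytes starting at the (possibly unaligned) address p using only
// CHUNK-aligned loads: fetch the two aligned chunks that cover the window and
// splice them together, as the slinep/slinec loops do with valign.
static void load_unaligned_via_aligned(const uint8_t * p, uint8_t * out) {
    const size_t off = reinterpret_cast<uintptr_t>(p) & (CHUNK - 1);
    const uint8_t * base = p - off;               // aligned address at or below p
    uint8_t prev[CHUNK], curr[CHUNK];
    std::memcpy(prev, base, CHUNK);               // "slinep"
    std::memcpy(curr, base + CHUNK, CHUNK);       // "slinec" (the real code skips this load when it is not needed)
    std::memcpy(out, prev + off, CHUNK - off);    // tail of the previous chunk
    std::memcpy(out + CHUNK - off, curr, off);    // head of the current chunk
}
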
diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h
index 80658105c5..566048297d 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-utils.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h
@@ -265,12 +265,16 @@ static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t
}
}
+
+/* Return whether the 'n' bytes starting at 'addr' fall within a single chunk of 'chunk_size' bytes. */
static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
+
+
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v };
@@ -994,6 +998,59 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
}
}
+
+static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){
+    int step_of_1 = num_elems >> 5; // divide by 32: 32 floats = 128 bytes = one HVX vector
+ int leftover = num_elems - (step_of_1 * VLEN_FP32);
+
+ int32_t leftover_size = leftover * sizeof(float);
+
+    static const float kMinExp = -87.f; // below this, sigmoid(x) is effectively 0
+    static const float kMaxExp = 87.f;  // above this, sigmoid(x) is effectively 1
+
+ const HVX_Vector one = hvx_vec_splat_fp32(1.f);
+ const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
+ const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
+
+    const float * input = (const float *) src;
+    float * output = (float *) dst;
+
+ HVX_Vector * input_v_ptr = (HVX_Vector *) input;
+ HVX_UVector * output_v_ptr = (HVX_UVector *) output;
+
+ HVX_Vector slinep;
+ HVX_Vector slinec;
+ HVX_Vector sline;
+
+ slinep = *input_v_ptr++;
+ #pragma unroll(4)
+ for (int i = step_of_1 - 1; i > 0; i--) {
+ slinec = *input_v_ptr++;
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+ *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+ /* Prepare slinep for next iteration */
+ slinep = slinec;
+ }
+
+ if (step_of_1 > 0) {
+ slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++;
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+        *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+
+ slinep = slinec;
+ }
+ if (leftover > 0) {
+ slinec = (is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++);
+
+ sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input);
+
+ HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);
+ hvx_vec_store_u(output_v_ptr, leftover_size, sout);
+ }
+}
+
+
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0,
const uint8_t * restrict src1,
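
hvx_sigmoid_f32 clamps the sigmoid argument to roughly ±87 (kMinExp/kMaxExp) so the underlying exponential stays within float range. A scalar reference of that guarded form (a sketch assuming the guard is a plain clamp, which is how the constants read):

#include <algorithm>
#include <cmath>

// Scalar reference for a range-guarded sigmoid: clamp x to [-87, 87] so that
// exp cannot overflow float (exp(88.73) already exceeds FLT_MAX), then apply
// sigmoid(x) = 1 / (1 + exp(-x)). At -87 the result is effectively 0, at +87 it is 1.
static float sigmoid_guarded(float x) {
    x = std::min(87.0f, std::max(-87.0f, x));
    return 1.0f / (1.0f + std::exp(-x));
}
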
diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c
index b60b352a7b..656c369d0a 100644
--- a/ggml/src/ggml-hexagon/htp/main.c
+++ b/ggml/src/ggml-hexagon/htp/main.c
@@ -798,6 +798,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break;
case HTP_OP_UNARY_SILU:
+ case HTP_OP_UNARY_GELU:
if (n_bufs != 2) {
FARF(ERROR, "Bad act-req buffer list");
continue;
@@ -806,6 +807,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break;
case HTP_OP_GLU_SWIGLU:
+ case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX:
if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list");
diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c
index c99b6a0d18..0c9188244d 100644
--- a/ggml/src/ggml-hexagon/htp/matmul-ops.c
+++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c
@@ -92,6 +92,18 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
+// vdelta control to replicate the first fp16 value of each 64-byte half across that half
+static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+ 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
+};
+
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@@ -903,7 +915,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
const float * restrict vy = (const float * restrict) y;
for (uint32_t i = 0; i < n; i++) {
- rsum += vx[i] * (__fp16) vy[i];
+ rsum += (float)vx[i] * vy[i];
}
*s = rsum;
return;
@@ -917,7 +929,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
// for some reason we need volatile here so that the compiler doesn't try anything funky
volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
-
+ float r_sum_scalar = 0.0f;
uint32_t i = 0;
for (i = 0; i < nv0; i++) {
@@ -926,31 +938,42 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
- HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+        // NOTE: volatile is needed here to prevent compiler optimization;
+        // the compiler does not seem to guarantee read-after-write ordering otherwise.
+ volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+ volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
}
if (nv1) {
- HVX_VectorPair yp = vy[i];
+ // HVX_VectorPair yp = vy[i];
- HVX_Vector x = vx[i];
- HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
+ // HVX_Vector x = vx[i];
+ // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0
- if (nv1 >= 32) {
- HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
- nv1 -= 32;
- }
+ // if (nv1 >= 32) {
+ // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
+ // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
+ // nv1 -= 32;
+ // }
+ // rsum = hvx_vec_qf32_reduce_sum(rsum);
+
+ // if (nv1) {
+ // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
+ // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
+ // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+ // }
+
+        // process the remainder using a scalar loop
rsum = hvx_vec_qf32_reduce_sum(rsum);
+ const __fp16 * restrict sx = (const __fp16 * restrict) x;
+ const float * restrict sy = (const float * restrict) y;
- if (nv1) {
- HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
- HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1);
- rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
+ for (uint32_t i = nv0 * 64; i < n; i++) {
+ r_sum_scalar += (float) sx[i] * sy[i];
}
// hvx_vec_dump_fp16("X", x);
@@ -961,7 +984,7 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
rsum = hvx_vec_qf32_reduce_sum(rsum);
}
- *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum));
+ *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar;
# ifdef HTP_DEBUG
{
@@ -1498,9 +1521,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
- const size_t src0_row_size = sizeof(__fp16) * ne00;
- const size_t src1_row_size = sizeof(float) * ne10;
-
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
@@ -1510,8 +1530,6 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
// This is the size of the rest of the dimensions of the result
const uint32_t nr1 = ne1 * ne2 * ne3;
- uint32_t chunk_size = 64;
-
// distribute the thread work across the inner or outer loop based on which one is larger
uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
@@ -1544,11 +1562,11 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
const uint32_t blck_0 = 64;
const uint32_t blck_1 = 64;
- float tmp[32];
+ __attribute__((aligned(128))) float tmp[64];
for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
- for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) {
+ for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) {
const uint32_t i13 = (ir1 / (ne12 * ne1));
const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
@@ -1561,13 +1579,16 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
const uint32_t i2 = i12;
const uint32_t i3 = i13;
- const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
+ const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03);
const uint8_t * restrict src1_col =
- (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size;
+ (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13);
float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
- for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) {
- vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col);
+ const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end);
+ for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) {
+ // Use nb01 stride for non-contiguous src0 support
+ const uint8_t * restrict src0_row = src0_base + ir0 * nb01;
+ vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col);
}
hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0);
@@ -1585,6 +1606,118 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
// *** dynamic quant
+static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+ assert((unsigned long) x % 128 == 0);
+ assert((unsigned long) y_q % 128 == 0);
+
+ HVX_Vector * vx = (HVX_Vector *) x;
+ HVX_Vector zero = Q6_V_vsplat_R(0);
+
+ // Use reduce max fp32 to find max(abs(e)) first
+ HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
+ HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
+ HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
+ HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
+ // Load and convert into QF32
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
+
+ // Convert to QF32
+ HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
+ HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
+ HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
+ HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
+
+ // Combine and convert to fp16
+ HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
+ HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
+
+ // Convert into fp16
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+ // Replicate first fp16 scale across all lanes
+ HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
+ vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
+ vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+
+ HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
+ HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
+
+ hvx_vec_store_u(y_d + 0, 2, vd01_hf);
+ HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
+ hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
+
+ hvx_vec_store_u(y_d + 4, 2, vd23_hf);
+ rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
+ hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
+
+ // Divide input by the scale
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
+
+ // Convert to int8
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+ *(HVX_Vector *) y_q = vx_i8;
+}
+
+static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
+ assert((unsigned long) x % 128 == 0);
+ assert((unsigned long) y_q % 128 == 0);
+
+ HVX_Vector * vx = (HVX_Vector *) x;
+
+ // Load and convert into QF32
+ HVX_Vector zero = Q6_V_vsplat_R(0);
+ HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
+ HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
+ HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
+ HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
+
+ // Convert into fp16
+ HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
+ HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
+
+ // Compute max and scale
+ HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
+ HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
+
+ // Replicate first fp16 scale across all lanes
+ HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
+ vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
+ vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
+
+ HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
+ HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
+ HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
+
+ hvx_vec_store_u(y_d + 0, 4, vd01_hf);
+ hvx_vec_store_u(y_d + 4, 4, vd23_hf);
+
+ // Divide input by the scale
+ HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
+ HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
+ vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
+ vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
+
+ // Convert to int8
+ HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
+ HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
+ HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
+
+ *(HVX_Vector *) y_q = vx_i8;
+}
+
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
@@ -1646,10 +1779,24 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
uint8_t * restrict t_d = (uint8_t *) x;
for (uint32_t i = 0; i < nb; i++) {
+#if FP32_QUANTIZE_GROUP_SIZE == 32
+ quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
+ t_d + (i * 2 + 0) * dblk_size / 2);
+ quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
+ t_d + (i * 2 + 1) * dblk_size / 2);
+#elif FP32_QUANTIZE_GROUP_SIZE == 64
+ quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
+ t_d + (i * 2 + 0) * dblk_size / 2);
+ quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
+ t_d + (i * 2 + 1) * dblk_size / 2);
+#elif FP32_QUANTIZE_GROUP_SIZE == 128
quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2);
+#else
+#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128"
+#endif
}
// now copy the scales into final location
@@ -1662,6 +1809,7 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
uint32_t nth,
uint32_t ith,
uint32_t nrows_per_thread) {
+
uint64_t t1 = HAP_perf_get_qtimer_count();
const uint32_t ne0 = src->ne[0];
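
The q8x1/q8x2/q8x4 block quantizers above differ only in how many floats share one scale (32, 64, or 128, selected by FP32_QUANTIZE_GROUP_SIZE); the underlying scheme is symmetric int8 quantization with d = max|x| / 127 per group. A scalar sketch of that scheme (reference only; the HVX code additionally stores the scales as fp16 and packs them separately):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of per-group symmetric int8 quantization: one scale per group of
// `group_size` floats, values mapped into [-127, 127] as q = round(x / d).
static void quantize_group_q8(const float * x, int group_size, int8_t * q, float * d) {
    float amax = 0.0f;
    for (int i = 0; i < group_size; ++i) {
        amax = std::max(amax, std::fabs(x[i]));
    }
    *d = amax / 127.0f;
    const float inv_d = (*d != 0.0f) ? 1.0f / *d : 0.0f;
    for (int i = 0; i < group_size; ++i) {
        const float v = std::round(x[i] * inv_d);
        q[i] = (int8_t) std::max(-127.0f, std::min(127.0f, v));
    }
}
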
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 18a45d2d96..13cf1f5f9d 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -583,7 +583,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
if (tensor->buffer) {
ggml_backend_buffer_t buffer = tensor->buffer;
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- result.buffer = ctx->remote_ptr;
+ result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
} else {
result.buffer = 0;
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 34ec09d403..a524adbe0c 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -689,6 +689,7 @@ struct vk_device_struct {
vk_pipeline pipeline_gelu_quick[2];
vk_pipeline pipeline_silu[2];
vk_pipeline pipeline_relu[2];
+ vk_pipeline pipeline_xielu[2];
vk_pipeline pipeline_neg[2];
vk_pipeline pipeline_tanh[2];
vk_pipeline pipeline_sigmoid[2];
@@ -730,7 +731,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
- vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
+ vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
@@ -855,6 +856,15 @@ struct vk_subbuffer {
}
};
+// vk_event is used for the event-related backend interfaces. It uses 'event' for
+// event_wait and 'fence' for event_synchronize. Polling on an event for
+// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
+// and would lead to validation errors.
+struct vk_event {
+ vk::Event event;
+ vk::Fence fence;
+};
+
struct vk_semaphore {
vk::Semaphore s;
uint64_t value;
@@ -990,6 +1000,8 @@ struct vk_op_push_constants {
uint32_t KY;
float param1;
float param2;
+ float param3;
+ float param4;
};
struct vk_op_glu_push_constants {
@@ -1258,6 +1270,7 @@ struct vk_op_im2col_push_constants {
int32_t s0; int32_t s1;
int32_t p0; int32_t p1;
int32_t d0; int32_t d1;
+ uint32_t batch_IC;
};
struct vk_op_im2col_3d_push_constants {
@@ -1527,6 +1540,8 @@ private:
#endif // GGML_VULKAN_MEMORY_DEBUG
static bool vk_perf_logger_enabled = false;
+static bool vk_perf_logger_concurrent = false;
+static bool vk_enable_sync_logger = false;
// number of calls between perf logger prints
static uint32_t vk_perf_logger_frequency = 1;
@@ -1577,14 +1592,14 @@ class vk_perf_logger {
flops.clear();
}
- void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
+ std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
+ *n_flops = 0;
std::string fusion_str;
if (fusion_name) {
fusion_str = fusion_name + std::string(" ");
}
if (node->op == GGML_OP_UNARY) {
- timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time);
- return;
+ return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->ne[0];
@@ -1606,9 +1621,8 @@ class vk_perf_logger {
name += " batch=" + std::to_string(batch);
}
name = fusion_str + name;
- timings[name].push_back(time);
- flops[name].push_back(m * n * (k + (k - 1)) * batch);
- return;
+ *n_flops = m * n * (k + (k - 1)) * batch;
+ return name;
}
if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
std::string name = ggml_op_name(node->op);
@@ -1624,20 +1638,17 @@ class vk_perf_logger {
uint64_t size_M = Cout;
uint64_t size_K = Cin * KW * KH;
uint64_t size_N = N * OW * OH;
- uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1));
+ *n_flops = size_M * size_N * (size_K + (size_K - 1));
name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
", N=N*OW*OH=" + std::to_string(size_N);
name = fusion_str + name;
- flops[name].push_back(n_flops);
- timings[name].push_back(time);
- return;
+ return name;
}
if (node->op == GGML_OP_RMS_NORM) {
std::string name = ggml_op_name(node->op);
name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
name = fusion_str + name;
- timings[name].push_back(time);
- return;
+ return name;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
const ggml_tensor * dst = node;
@@ -1653,8 +1664,7 @@ class vk_perf_logger {
" k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
" v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
" m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
- timings[name.str()].push_back(time);
- return;
+ return name.str();
}
if (node->op == GGML_OP_TOP_K) {
std::stringstream name;
@@ -1662,11 +1672,38 @@ class vk_perf_logger {
name << ggml_op_name(node->op) <<
" K=" << node->ne[0] <<
" (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
- timings[name.str()].push_back(time);
- return;
+ return name.str();
}
- timings[fusion_str + ggml_op_name(node->op)].push_back(time);
+ return fusion_str + ggml_op_name(node->op);
}
+
+ void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
+ uint64_t n_flops;
+ std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
+ if (n_flops) {
+ flops[name].push_back(n_flops);
+ }
+ timings[name].push_back(time);
+ }
+
+ void log_timing(const std::vector