Merge branch 'master' into dev-ocl-test-fix
This commit is contained in:
commit
6ebbac9715
|
|
@ -89,7 +89,10 @@ nix:
|
|||
embedding:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: examples/embedding/
|
||||
|
||||
jinja parser:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- common/jinja/**
|
||||
Ascend NPU:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Get latest Vulkan SDK version
|
||||
id: vulkan_sdk_version
|
||||
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
|
|
@ -47,10 +47,10 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
|
|
@ -73,10 +73,10 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ jobs:
|
|||
linux:
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ jobs:
|
|||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
|
|
@ -52,7 +52,7 @@ jobs:
|
|||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Riscv
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture riscv64
|
||||
|
|
@ -99,7 +99,7 @@ jobs:
|
|||
# runs-on: ubuntu-24.04
|
||||
|
||||
# steps:
|
||||
# - uses: actions/checkout@v4
|
||||
# - uses: actions/checkout@v6
|
||||
# - name: Setup Arm64
|
||||
# run: |
|
||||
# sudo dpkg --add-architecture arm64
|
||||
|
|
@ -146,7 +146,7 @@ jobs:
|
|||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
|
|
@ -201,7 +201,7 @@ jobs:
|
|||
container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Setup LoongArch
|
||||
run: |
|
||||
rm -f /etc/apt/sources.list.d/*
|
||||
|
|
@ -262,10 +262,10 @@ jobs:
|
|||
SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2"
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Use SpacemiT Toolchain Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-toolchain
|
||||
with:
|
||||
path: ./spacemit_toolchain
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ on:
|
|||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
pull_request:
|
||||
|
|
@ -42,7 +43,8 @@ on:
|
|||
'**/*.m',
|
||||
'**/*.metal',
|
||||
'**/*.comp',
|
||||
'**/*.glsl'
|
||||
'**/*.glsl',
|
||||
'**/*.wgsl'
|
||||
]
|
||||
|
||||
concurrency:
|
||||
|
|
@ -63,7 +65,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -99,7 +101,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -135,7 +137,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -189,7 +191,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -269,7 +271,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -317,7 +319,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
|
@ -347,7 +349,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# - name: ccache
|
||||
# uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -380,7 +382,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -414,7 +416,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -436,7 +438,7 @@ jobs:
|
|||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
|
|
@ -472,7 +474,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -494,7 +496,7 @@ jobs:
|
|||
echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Use Vulkan SDK Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-sdk
|
||||
with:
|
||||
path: ./vulkan_sdk
|
||||
|
|
@ -543,7 +545,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -585,7 +587,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
|
@ -616,7 +618,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
|
@ -644,7 +646,7 @@ jobs:
|
|||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
|
|
@ -668,7 +670,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -693,7 +695,7 @@ jobs:
|
|||
continue-on-error: true
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: add oneAPI to apt
|
||||
shell: bash
|
||||
|
|
@ -717,7 +719,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -749,7 +751,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -781,7 +783,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -813,7 +815,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
|
|
@ -843,7 +845,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -853,7 +855,7 @@ jobs:
|
|||
save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
|
||||
|
||||
- name: Download xcframework artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
|
|
@ -885,7 +887,7 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -954,7 +956,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1053,7 +1055,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
|
|
@ -1092,7 +1094,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1145,7 +1147,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1177,7 +1179,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
|
|
@ -1187,7 +1189,7 @@ jobs:
|
|||
7z x data.tar
|
||||
|
||||
- name: Use ROCm Installation Cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
id: cache-rocm
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
|
|
@ -1239,7 +1241,7 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Xcode
|
||||
uses: maxim-lobanov/setup-xcode@v1
|
||||
|
|
@ -1269,7 +1271,7 @@ jobs:
|
|||
./build-xcframework.sh
|
||||
|
||||
- name: Upload xcframework artifact
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: llama-xcframework
|
||||
path: build-apple/llama.xcframework/
|
||||
|
|
@ -1285,7 +1287,7 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
# Disabled due to size (400MB) and always 0 cache hits
|
||||
# - name: ccache
|
||||
|
|
@ -1295,7 +1297,7 @@ jobs:
|
|||
# evict-old-files: 1d
|
||||
|
||||
- name: Set up JDK
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v5
|
||||
with:
|
||||
java-version: 17
|
||||
distribution: zulu
|
||||
|
|
@ -1327,7 +1329,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install OpenCL Headers and Libs
|
||||
id: install_opencl
|
||||
|
|
@ -1371,7 +1373,7 @@ jobs:
|
|||
id: update_presets
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
cp docs/backend/snapdragon/CMakeUserPresets.json .
|
||||
|
||||
- name: Build
|
||||
id: ndk_build
|
||||
|
|
@ -1402,7 +1404,7 @@ jobs:
|
|||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -1460,7 +1462,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1486,7 +1488,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1512,7 +1514,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1538,7 +1540,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1564,7 +1566,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1590,7 +1592,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1604,7 +1606,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1618,7 +1620,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1632,7 +1634,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1645,7 +1647,7 @@ jobs:
|
|||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
|
|
@ -1659,7 +1661,7 @@ jobs:
|
|||
# steps:
|
||||
# - name: Clone
|
||||
# id: checkout
|
||||
# uses: actions/checkout@v4
|
||||
# uses: actions/checkout@v6
|
||||
|
||||
# - name: Test
|
||||
# id: ggml-ci
|
||||
|
|
@ -1673,7 +1675,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1686,7 +1688,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dawn Dependency
|
||||
id: dawn-depends
|
||||
|
|
@ -1714,7 +1716,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
|
|
@ -1728,7 +1730,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -1773,7 +1775,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Check environment
|
||||
run: |
|
||||
|
|
@ -1875,7 +1877,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
|
|
@ -1969,7 +1971,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
|
|
@ -2043,7 +2045,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
|
|
@ -2089,7 +2091,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
|
|
|
|||
|
|
@ -19,16 +19,16 @@ on:
|
|||
|
||||
jobs:
|
||||
check-vendor:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ permissions:
|
|||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
|
||||
days-before-issue-stale: 30
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ jobs:
|
|||
# If you do not check out your code, Copilot will do this for you.
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -45,7 +45,7 @@ jobs:
|
|||
sudo chmod +x /usr/local/bin/git-clang-format
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ jobs:
|
|||
- { tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true, runs_on: "ubuntu-22.04" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0 # preserve git history, so we can determine the build number
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ jobs:
|
|||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
|
|
@ -208,7 +208,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
editorconfig:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- uses: editorconfig-checker/action-editorconfig-checker@v2
|
||||
with:
|
||||
version: v3.0.3
|
||||
|
|
|
|||
|
|
@ -21,12 +21,12 @@ on:
|
|||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.9.x'
|
||||
- name: Install dependencies
|
||||
|
|
|
|||
|
|
@ -7,11 +7,11 @@ jobs:
|
|||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
repository: "ggml-org/llama.cpp"
|
||||
- uses: actions/labeler@v5
|
||||
- uses: actions/labeler@v6
|
||||
with:
|
||||
configuration-path: '.github/labeler.yml'
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ on:
|
|||
|
||||
jobs:
|
||||
pre-tokenizer-hashes:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
|
|
|||
|
|
@ -20,13 +20,13 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
python-check-requirements:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
name: check-requirements
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Run check-requirements.sh script
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ concurrency:
|
|||
|
||||
jobs:
|
||||
flake8-lint:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
name: Lint
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: flake8 Lint
|
||||
|
|
|
|||
|
|
@ -24,14 +24,12 @@ jobs:
|
|||
name: pyright type-check
|
||||
steps:
|
||||
- name: Check out source repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Python environment
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Python dependencies
|
||||
# TODO: use a venv
|
||||
run: pip install -r requirements/requirements-all.txt
|
||||
pip-install: -r requirements/requirements-all.txt
|
||||
- name: Type-check with Pyright
|
||||
uses: jakebailey/pyright-action@v2
|
||||
with:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -63,7 +63,7 @@ jobs:
|
|||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.tar.gz
|
||||
name: llama-bin-macos-arm64.tar.gz
|
||||
|
|
@ -74,7 +74,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -111,7 +111,7 @@ jobs:
|
|||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz -s ",./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.tar.gz
|
||||
name: llama-bin-macos-x64.tar.gz
|
||||
|
|
@ -133,7 +133,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -173,7 +173,7 @@ jobs:
|
|||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
name: llama-bin-ubuntu-${{ matrix.build }}.tar.gz
|
||||
|
|
@ -184,7 +184,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -226,7 +226,7 @@ jobs:
|
|||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz
|
||||
name: llama-bin-ubuntu-vulkan-x64.tar.gz
|
||||
|
|
@ -242,7 +242,7 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Clone
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -278,7 +278,7 @@ jobs:
|
|||
7z a -snl llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-cpu-${{ matrix.arch }}.zip
|
||||
|
|
@ -305,7 +305,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -360,7 +360,7 @@ jobs:
|
|||
7z a -snl llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip
|
||||
|
|
@ -375,7 +375,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Install ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -416,7 +416,7 @@ jobs:
|
|||
7z a -snl llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
|
|
@ -431,7 +431,7 @@ jobs:
|
|||
7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\*
|
||||
|
||||
- name: Upload Cuda runtime
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip
|
||||
|
|
@ -451,7 +451,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
|
|
@ -511,7 +511,7 @@ jobs:
|
|||
7z a -snl llama-bin-win-sycl-x64.zip ./build/bin/*
|
||||
|
||||
- name: Upload the release package
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-sycl-x64.zip
|
||||
name: llama-bin-win-sycl-x64.zip
|
||||
|
|
@ -531,7 +531,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Grab rocWMMA package
|
||||
id: grab_rocwmma
|
||||
|
|
@ -542,7 +542,7 @@ jobs:
|
|||
|
||||
- name: Cache ROCm Installation
|
||||
id: cache-rocm
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: C:\Program Files\AMD\ROCm
|
||||
key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
|
||||
|
|
@ -617,7 +617,7 @@ jobs:
|
|||
7z a -snl llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
name: llama-bin-win-hip-${{ matrix.name }}-x64.zip
|
||||
|
|
@ -627,7 +627,7 @@ jobs:
|
|||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -672,7 +672,7 @@ jobs:
|
|||
zip -r -y llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
name: llama-${{ steps.tag.outputs.name }}-xcframework.zip
|
||||
|
|
@ -703,7 +703,7 @@ jobs:
|
|||
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -763,7 +763,7 @@ jobs:
|
|||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: llama-${{ steps.tag.outputs.name }}-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
name: llama-bin-${{ matrix.chip_type }}-openEuler-${{ matrix.arch }}${{ matrix.use_acl_graph == 'on' && '-aclgraph' || '' }}.tar.gz
|
||||
|
|
@ -794,7 +794,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
|
|
@ -804,7 +804,7 @@ jobs:
|
|||
|
||||
- name: Download artifacts
|
||||
id: download-artifact
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
path: ./artifact
|
||||
merge-multiple: true
|
||||
|
|
@ -887,7 +887,7 @@ jobs:
|
|||
|
||||
- name: Upload release
|
||||
id: upload_release
|
||||
uses: actions/github-script@v3
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
github-token: ${{secrets.GITHUB_TOKEN}}
|
||||
script: |
|
||||
|
|
@ -897,7 +897,7 @@ jobs:
|
|||
for (let file of await fs.readdirSync('./release')) {
|
||||
if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) {
|
||||
console.log('uploadReleaseAsset', file);
|
||||
await github.repos.uploadReleaseAsset({
|
||||
await github.rest.repos.uploadReleaseAsset({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
release_id: release_id,
|
||||
|
|
|
|||
|
|
@ -37,14 +37,14 @@ jobs:
|
|||
continue-on-error: true
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Setup Node.js
|
||||
id: node
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
|
|
@ -131,14 +131,14 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
|
@ -148,7 +148,7 @@ jobs:
|
|||
pip install -r tools/server/tests/requirements.txt
|
||||
|
||||
- name: Setup Node.js for WebUI
|
||||
uses: actions/setup-node@v4
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: "22"
|
||||
cache: "npm"
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ jobs:
|
|||
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
|
@ -72,12 +72,12 @@ jobs:
|
|||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
|
@ -100,7 +100,7 @@ jobs:
|
|||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
|
||||
|
|
@ -108,12 +108,12 @@ jobs:
|
|||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON
|
||||
cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
|
||||
cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
|
||||
|
||||
- name: Python setup
|
||||
id: setup_python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ on:
|
|||
|
||||
jobs:
|
||||
update-ops-docs:
|
||||
runs-on: ubuntu-latest
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.x'
|
||||
|
||||
|
|
|
|||
|
|
@ -21,23 +21,24 @@ jobs:
|
|||
|
||||
- name: Find latest release
|
||||
id: find_latest_release
|
||||
uses: actions/github-script@v6
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const { data: releases } = await github.rest.repos.listReleases({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
console.log("Latest release:", releases[0].tag_name);
|
||||
return releases[0].tag_name;
|
||||
const { tag_name: version, assets: assets } = releases.find(({assets}) => assets.find(asset => asset.name.includes('win-vulkan')));
|
||||
const { browser_download_url: asset_url } = assets.find(asset => asset.name.includes('win-vulkan'));
|
||||
console.log("Latest release:", version);
|
||||
core.setOutput('VERSION', version);
|
||||
core.setOutput('ASSETURL', asset_url);
|
||||
|
||||
- name: Update manifest
|
||||
env:
|
||||
VERSION: ${{ steps.find_latest_release.outputs.result }}
|
||||
run: |
|
||||
echo "Updating manifest..."
|
||||
komac update --version ${{ env.VERSION }} \
|
||||
--urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \
|
||||
komac update --version ${{ steps.find_latest_release.outputs.VERSION }} \
|
||||
--urls "${{ steps.find_latest_release.outputs.ASSETURL }}" \
|
||||
--token ${{ secrets.WINGET_GITHUB_TOKEN }} \
|
||||
--submit \
|
||||
ggml.llamacpp
|
||||
|
|
|
|||
|
|
@ -15,8 +15,10 @@
|
|||
/common/common.* @ggerganov
|
||||
/common/console.* @ggerganov
|
||||
/common/http.* @angt
|
||||
/common/jinja/ @ngxson @CISC @aldehir
|
||||
/common/llguidance.* @ggerganov
|
||||
/common/log.* @ggerganov
|
||||
/common/ngram-map.* @srogmann
|
||||
/common/peg-parser.* @aldehir
|
||||
/common/sampling.* @ggerganov
|
||||
/common/speculative.* @ggerganov
|
||||
|
|
@ -66,6 +68,7 @@
|
|||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
/ggml/src/ggml-threading.* @ggerganov
|
||||
/ggml/src/ggml-vulkan/ @0cc4m
|
||||
/ggml/src/ggml-virtgpu/ @kpouget
|
||||
/ggml/src/ggml-webgpu/ @reeselevine
|
||||
/ggml/src/ggml-zdnn/ @taronaeo @Andreas-Krebbel @AlekseiNikiforovIBM
|
||||
/ggml/src/ggml.c @ggerganov
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
|
||||
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
|
||||
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
|
||||
- [x] [RWKV-7](https://huggingface.co/collections/shoumenchougou/rwkv7-gxx-gguf)
|
||||
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
|
||||
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
|
|
@ -212,6 +213,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [llama.vim](https://github.com/ggml-org/llama.vim) (MIT)
|
||||
- [LARS](https://github.com/abgulati/LARS) (AGPL)
|
||||
- [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL)
|
||||
- [LlamaLib](https://github.com/undreamai/LlamaLib) (Apache-2.0)
|
||||
- [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT)
|
||||
- [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT)
|
||||
- [LMStudio](https://lmstudio.ai/) (proprietary)
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ function gg_run_ctest_release {
|
|||
(time make -j$(nproc) ) 2>&1 | tee -a $OUT/${ci}-make.log
|
||||
|
||||
if [ -z ${GG_BUILD_LOW_PERF} ]; then
|
||||
(time ctest --output-on-failure -L main ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
(time ctest --output-on-failure -L 'main|python' ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
else
|
||||
(time ctest --output-on-failure -L main -E test-opt ) 2>&1 | tee -a $OUT/${ci}-ctest.log
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -73,6 +73,10 @@ add_library(${TARGET} STATIC
|
|||
log.h
|
||||
ngram-cache.cpp
|
||||
ngram-cache.h
|
||||
ngram-map.cpp
|
||||
ngram-map.h
|
||||
ngram-mod.cpp
|
||||
ngram-mod.h
|
||||
peg-parser.cpp
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
|
|
|
|||
149
common/arg.cpp
149
common/arg.cpp
|
|
@ -6,6 +6,7 @@
|
|||
#include "json-schema-to-grammar.h"
|
||||
#include "log.h"
|
||||
#include "sampling.h"
|
||||
#include "speculative.h"
|
||||
#include "preset.h"
|
||||
|
||||
// fix problem with std::min and std::max
|
||||
|
|
@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
|||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (auto & ex : mmproj_examples) {
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (ctx_arg.ex == ex) {
|
||||
common_params_handle_model(params.mmproj, params.hf_token, params.offline);
|
||||
break;
|
||||
}
|
||||
}
|
||||
common_params_handle_model(params.speculative.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline);
|
||||
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
|
||||
}
|
||||
|
||||
// model is required (except for server)
|
||||
|
|
@ -1216,21 +1217,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
{"-lcs", "--lookup-cache-static"}, "FNAME",
|
||||
"path to static lookup cache to use for lookup decoding (not updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_static = value;
|
||||
params.speculative.lookup_cache_static = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-lcd", "--lookup-cache-dynamic"}, "FNAME",
|
||||
"path to dynamic lookup cache to use for lookup decoding (updated by generation)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.lookup_cache_dynamic = value;
|
||||
params.speculative.lookup_cache_dynamic = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP}));
|
||||
).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-c", "--ctx-size"}, "N",
|
||||
string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx),
|
||||
[](common_params & params, int value) {
|
||||
params.n_ctx = value;
|
||||
if (value == 0) {
|
||||
// disable context reduction in llama_params_fit if the user explicitly requests the full context size:
|
||||
params.fit_params_min_ctx = UINT32_MAX;
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_SIZE"));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -1291,11 +1296,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"-kvu", "--kv-unified"},
|
||||
{"-no-kvu", "--no-kv-unified"},
|
||||
"use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)",
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
[](common_params & params, bool value) {
|
||||
params.kv_unified = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
|
||||
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
|
||||
add_opt(common_arg(
|
||||
{"--context-shift"},
|
||||
{"--no-context-shift"},
|
||||
|
|
@ -1573,7 +1579,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--temp"}, "N",
|
||||
string_format("temperature (default: %.1f)", (double)params.sampling.temp),
|
||||
string_format("temperature (default: %.2f)", (double)params.sampling.temp),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.temp = std::stof(value);
|
||||
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
||||
|
|
@ -1590,7 +1596,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam().set_env("LLAMA_ARG_TOP_K"));
|
||||
add_opt(common_arg(
|
||||
{"--top-p"}, "N",
|
||||
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
string_format("top-p sampling (default: %.2f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
|
||||
|
|
@ -1598,7 +1604,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--min-p"}, "N",
|
||||
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
string_format("min-p sampling (default: %.2f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.min_p = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
|
||||
|
|
@ -1606,14 +1612,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--top-nsigma"}, "N",
|
||||
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.top_n_sigma = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-probability"}, "N",
|
||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
string_format("xtc probability (default: %.2f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_probability = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
|
||||
|
|
@ -1621,7 +1627,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--xtc-threshold"}, "N",
|
||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
string_format("xtc threshold (default: %.2f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.xtc_threshold = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
|
||||
|
|
@ -1629,7 +1635,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--typical"}, "N",
|
||||
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.typ_p = std::stof(value);
|
||||
}
|
||||
|
|
@ -1648,7 +1654,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--repeat-penalty"}, "N",
|
||||
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
string_format("penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_repeat = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
|
||||
|
|
@ -1656,21 +1662,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--presence-penalty"}, "N",
|
||||
string_format("repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
string_format("repeat alpha presence penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_present),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_present = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--frequency-penalty"}, "N",
|
||||
string_format("repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
string_format("repeat alpha frequency penalty (default: %.2f, 0.0 = disabled)", (double)params.sampling.penalty_freq),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.penalty_freq = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dry-multiplier"}, "N",
|
||||
string_format("set DRY sampling multiplier (default: %.1f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
string_format("set DRY sampling multiplier (default: %.2f, 0.0 = disabled)", (double)params.sampling.dry_multiplier),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dry_multiplier = std::stof(value);
|
||||
}
|
||||
|
|
@ -1751,14 +1757,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-range"}, "N",
|
||||
string_format("dynamic temperature range (default: %.1f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
string_format("dynamic temperature range (default: %.2f, 0.0 = disabled)", (double)params.sampling.dynatemp_range),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_range = std::stof(value);
|
||||
}
|
||||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--dynatemp-exp"}, "N",
|
||||
string_format("dynamic temperature exponent (default: %.1f)", (double)params.sampling.dynatemp_exponent),
|
||||
string_format("dynamic temperature exponent (default: %.2f)", (double)params.sampling.dynatemp_exponent),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.dynatemp_exponent = std::stof(value);
|
||||
}
|
||||
|
|
@ -1774,7 +1780,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-lr"}, "N",
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
||||
string_format("Mirostat learning rate, parameter eta (default: %.2f)", (double)params.sampling.mirostat_eta),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_eta = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
|
||||
|
|
@ -1782,7 +1788,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_sparam());
|
||||
add_opt(common_arg(
|
||||
{"--mirostat-ent"}, "N",
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
||||
string_format("Mirostat target entropy, parameter tau (default: %.2f)", (double)params.sampling.mirostat_tau),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.sampling.mirostat_tau = std::stof(value);
|
||||
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
|
||||
|
|
@ -1916,28 +1922,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_env("LLAMA_ARG_YARN_ORIG_CTX"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-ext-factor"}, "N",
|
||||
string_format("YaRN: extrapolation mix factor (default: %.1f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
string_format("YaRN: extrapolation mix factor (default: %.2f, 0.0 = full interpolation)", (double)params.yarn_ext_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_ext_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_EXT_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-attn-factor"}, "N",
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.1f)", (double)params.yarn_attn_factor),
|
||||
string_format("YaRN: scale sqrt(t) or attention magnitude (default: %.2f)", (double)params.yarn_attn_factor),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_attn_factor = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_ATTN_FACTOR"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-slow"}, "N",
|
||||
string_format("YaRN: high correction dim or alpha (default: %.1f)", (double)params.yarn_beta_slow),
|
||||
string_format("YaRN: high correction dim or alpha (default: %.2f)", (double)params.yarn_beta_slow),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_slow = std::stof(value);
|
||||
}
|
||||
).set_env("LLAMA_ARG_YARN_BETA_SLOW"));
|
||||
add_opt(common_arg(
|
||||
{"--yarn-beta-fast"}, "N",
|
||||
string_format("YaRN: low correction dim or beta (default: %.1f)", (double)params.yarn_beta_fast),
|
||||
string_format("YaRN: low correction dim or beta (default: %.2f)", (double)params.yarn_beta_fast),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.yarn_beta_fast = std::stof(value);
|
||||
}
|
||||
|
|
@ -2194,18 +2200,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
add_opt(common_arg(
|
||||
{"--mmap"},
|
||||
{"--no-mmap"},
|
||||
string_format("whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
|
||||
string_format("whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: %s)", params.use_mmap ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.use_mmap = value;
|
||||
if (value) {
|
||||
params.use_direct_io = false; // disable direct io when mmap is explicitly enabled
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_ARG_MMAP"));
|
||||
add_opt(common_arg(
|
||||
{"-dio", "--direct-io"},
|
||||
{"-ndio", "--no-direct-io"},
|
||||
string_format("use DirectIO if available. Takes precedence over --mmap (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
|
||||
string_format("use DirectIO if available. (default: %s)", params.use_direct_io ? "enabled" : "disabled"),
|
||||
[](common_params & params, bool value) {
|
||||
params.use_direct_io = value;
|
||||
}
|
||||
|
|
@ -2561,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
|
||||
"Same as --hf-repo, but for the draft model (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.hf_repo = value;
|
||||
params.speculative.mparams_dft.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HFD_REPO"));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -3331,14 +3334,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_MIN"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-split"}, "P",
|
||||
string_format("speculative decoding split probability (default: %.1f)", (double)params.speculative.p_split),
|
||||
string_format("speculative decoding split probability (default: %.2f)", (double)params.speculative.p_split),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_split = std::stof(value);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE}).set_env("LLAMA_ARG_DRAFT_P_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--draft-p-min"}, "P",
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.1f)", (double)params.speculative.p_min),
|
||||
string_format("minimum speculative decoding probability (greedy) (default: %.2f)", (double)params.speculative.p_min),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.p_min = std::stof(value);
|
||||
}
|
||||
|
|
@ -3382,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
{"-md", "--model-draft"}, "FNAME",
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.model.path = value;
|
||||
params.speculative.mparams_dft.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT"));
|
||||
add_opt(common_arg(
|
||||
|
|
@ -3392,6 +3395,68 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.speculative.replacements.push_back({ tgt, dft });
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
|
||||
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
|
||||
common_speculative_type_to_str(params.speculative.type).c_str()),
|
||||
[](common_params & params, const std::string & value) {
|
||||
if (value == "none") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
|
||||
} else if (value == "ngram-cache") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
|
||||
} else if (value == "ngram-simple") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE;
|
||||
} else if (value == "ngram-map-k") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
|
||||
} else if (value == "ngram-map-k4v") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
|
||||
} else if (value == "ngram-mod") {
|
||||
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
|
||||
} else {
|
||||
throw std::invalid_argument("unknown speculative decoding type without draft model");
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-n"}, "N",
|
||||
string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_n = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-size-m"}, "N",
|
||||
string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1 || value > 1024) {
|
||||
throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive");
|
||||
}
|
||||
params.speculative.ngram_size_m = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-check-rate"}, "N",
|
||||
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram check rate must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_check_rate = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--spec-ngram-min-hits"}, "N",
|
||||
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),
|
||||
[](common_params & params, int value) {
|
||||
if (value < 1) {
|
||||
throw std::invalid_argument("ngram min hits must be at least 1");
|
||||
}
|
||||
params.speculative.ngram_min_hits = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"-ctkd", "--cache-type-k-draft"}, "TYPE",
|
||||
string_format(
|
||||
|
|
@ -3618,8 +3683,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
|
|
@ -3634,8 +3699,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF";
|
||||
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
|
||||
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
|
||||
params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
|
||||
params.port = 8012;
|
||||
params.n_ubatch = 1024;
|
||||
params.n_batch = 1024;
|
||||
|
|
|
|||
|
|
@ -129,7 +129,7 @@ static void parse_json_tool_calls(
|
|||
}
|
||||
}
|
||||
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
|
||||
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax)
|
||||
: input_(input), is_partial_(is_partial), syntax_(syntax)
|
||||
{
|
||||
result_.role = "assistant";
|
||||
|
|
@ -1611,7 +1611,7 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
|||
builder.finish();
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE ||
|
||||
syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) {
|
||||
|
|
@ -1630,12 +1630,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
|
|||
}
|
||||
auto msg = builder.result();
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) {
|
||||
if (parser.empty()) {
|
||||
throw std::runtime_error("Failed to parse due to missing parser definition.");
|
||||
}
|
||||
|
|
@ -1663,7 +1663,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std
|
|||
mapper.from_ast(ctx.ast, result);
|
||||
}
|
||||
if (!is_partial) {
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
|
||||
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str());
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
#include "json-partial.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
|
@ -19,20 +19,20 @@ class common_chat_msg_partial_exception : public std::runtime_error {
|
|||
class common_chat_msg_parser {
|
||||
std::string input_;
|
||||
bool is_partial_;
|
||||
common_chat_syntax syntax_;
|
||||
common_chat_parser_params syntax_; // TODO: rename to params
|
||||
std::string healing_marker_;
|
||||
|
||||
size_t pos_ = 0;
|
||||
common_chat_msg result_;
|
||||
|
||||
public:
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
const std::string & input() const { return input_; }
|
||||
size_t pos() const { return pos_; }
|
||||
const std::string & healing_marker() const { return healing_marker_; }
|
||||
const bool & is_partial() const { return is_partial_; }
|
||||
const common_chat_msg & result() const { return result_; }
|
||||
const common_chat_syntax & syntax() const { return syntax_; }
|
||||
const common_chat_parser_params & syntax() const { return syntax_; }
|
||||
|
||||
void move_to(size_t pos) {
|
||||
if (pos > input_.size()) {
|
||||
|
|
|
|||
412
common/chat.cpp
412
common/chat.cpp
|
|
@ -7,9 +7,6 @@
|
|||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
// #include <minja/chat-template.hpp>
|
||||
// #include <minja/minja.hpp>
|
||||
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/value.h"
|
||||
#include "jinja/runtime.h"
|
||||
|
|
@ -56,39 +53,73 @@ static bool has_content_or_tool_calls(const common_chat_msg & msg) {
|
|||
return !msg.content.empty() || !msg.tool_calls.empty();
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msg::to_json_oaicompat() const
|
||||
{
|
||||
json message {
|
||||
{"role", "assistant"},
|
||||
};
|
||||
if (!reasoning_content.empty()) {
|
||||
message["reasoning_content"] = reasoning_content;
|
||||
json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
|
||||
if (!content.empty() && !content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
if (content.empty() && !tool_calls.empty()) {
|
||||
message["content"] = json();
|
||||
json jmsg {
|
||||
{"role", role},
|
||||
};
|
||||
if (!content.empty()) {
|
||||
jmsg["content"] = content;
|
||||
} else if (!content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
message["content"] = content;
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = reasoning_content;
|
||||
}
|
||||
if (!tool_name.empty()) {
|
||||
jmsg["name"] = tool_name;
|
||||
}
|
||||
if (!tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = tool_call_id;
|
||||
}
|
||||
if (!tool_calls.empty()) {
|
||||
auto arr = json::array();
|
||||
for (const auto & tc : tool_calls) {
|
||||
arr.push_back({
|
||||
jmsg["tool_calls"] = json::array();
|
||||
auto & jtool_calls = jmsg["tool_calls"];
|
||||
for (const auto & tool_call : tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tc.name},
|
||||
{"arguments", tc.arguments},
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
{"id", tc.id},
|
||||
// // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// // We only generate a random id for the ones that don't generate one by themselves
|
||||
// // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
});
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
|
||||
// We only generate a random id for the ones that don't generate one by themselves
|
||||
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
|
||||
// {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
|
||||
jtool_calls.push_back(tc);
|
||||
}
|
||||
message["tool_calls"] = arr;
|
||||
}
|
||||
return message;
|
||||
|
||||
return jmsg;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) {
|
||||
|
|
@ -256,7 +287,6 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
|
|||
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
|
||||
std::vector<common_chat_msg> msgs;
|
||||
|
||||
|
|
@ -350,80 +380,15 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
|
|||
return msgs;
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
|
||||
json messages = json::array();
|
||||
for (const auto & msg : msgs) {
|
||||
if (!msg.content.empty() && !msg.content_parts.empty()) {
|
||||
throw std::runtime_error("Cannot specify both content and content_parts");
|
||||
}
|
||||
json jmsg {
|
||||
{"role", msg.role},
|
||||
};
|
||||
if (!msg.content.empty()) {
|
||||
jmsg["content"] = msg.content;
|
||||
} else if (!msg.content_parts.empty()) {
|
||||
if (concat_typed_text) {
|
||||
std::string text;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
if (part.type != "text") {
|
||||
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
|
||||
continue;
|
||||
}
|
||||
if (!text.empty()) {
|
||||
text += '\n';
|
||||
}
|
||||
text += part.text;
|
||||
}
|
||||
jmsg["content"] = text;
|
||||
} else {
|
||||
auto & parts = jmsg["content"] = json::array();
|
||||
for (const auto & part : msg.content_parts) {
|
||||
parts.push_back({
|
||||
{"type", part.type},
|
||||
{"text", part.text},
|
||||
});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
jmsg["content"] = "";
|
||||
}
|
||||
if (!msg.reasoning_content.empty()) {
|
||||
jmsg["reasoning_content"] = msg.reasoning_content;
|
||||
}
|
||||
if (!msg.tool_name.empty()) {
|
||||
jmsg["name"] = msg.tool_name;
|
||||
}
|
||||
if (!msg.tool_call_id.empty()) {
|
||||
jmsg["tool_call_id"] = msg.tool_call_id;
|
||||
}
|
||||
if (!msg.tool_calls.empty()) {
|
||||
auto & tool_calls = jmsg["tool_calls"] = json::array();
|
||||
for (const auto & tool_call : msg.tool_calls) {
|
||||
json tc {
|
||||
{"type", "function"},
|
||||
{"function", {
|
||||
{"name", tool_call.name},
|
||||
{"arguments", tool_call.arguments},
|
||||
}},
|
||||
};
|
||||
if (!tool_call.id.empty()) {
|
||||
tc["id"] = tool_call.id;
|
||||
}
|
||||
tool_calls.push_back(tc);
|
||||
}
|
||||
}
|
||||
json jmsg = msg.to_json_oaicompat(concat_typed_text);
|
||||
messages.push_back(jmsg);
|
||||
}
|
||||
return messages;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
|
||||
return common_chat_msgs_parse_oaicompat(json::parse(messages));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
|
||||
std::vector<common_chat_tool> result;
|
||||
|
||||
|
|
@ -459,12 +424,6 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
|
|||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
|
||||
return common_chat_tools_parse_oaicompat(json::parse(tools));
|
||||
}
|
||||
|
||||
template <>
|
||||
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
|
||||
if (tools.empty()) {
|
||||
return json();
|
||||
|
|
@ -484,7 +443,7 @@ json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & t
|
|||
return result;
|
||||
}
|
||||
|
||||
template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json delta = json::object();
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
delta["reasoning_content"] = diff.reasoning_content_delta;
|
||||
|
|
@ -601,18 +560,18 @@ bool common_chat_templates_was_explicit(const struct common_chat_templates * tmp
|
|||
return tmpls->has_explicit_template;
|
||||
}
|
||||
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
|
||||
if (variant != nullptr) {
|
||||
if (strcmp(variant, "tool_use") == 0) {
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant) {
|
||||
if (!variant.empty()) {
|
||||
if (variant == "tool_use") {
|
||||
if (tmpls->template_tool_use) {
|
||||
return tmpls->template_tool_use->source().c_str();
|
||||
return tmpls->template_tool_use->source();
|
||||
}
|
||||
return nullptr;
|
||||
return "";
|
||||
} else {
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
|
||||
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str());
|
||||
}
|
||||
}
|
||||
return tmpls->template_default->source().c_str();
|
||||
return tmpls->template_default->source();
|
||||
}
|
||||
|
||||
common_chat_templates_ptr common_chat_templates_init(
|
||||
|
|
@ -812,10 +771,12 @@ static std::string apply(
|
|||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
{"messages", messages_override.has_value() ? *messages_override : inputs.messages},
|
||||
{"tools", tools_override.has_value() ? *tools_override : inputs.tools},
|
||||
{"bos_token", tmpl.bos_token()},
|
||||
{"eos_token", tmpl.eos_token()},
|
||||
};
|
||||
if (tools_override.has_value() || !inputs.tools.empty()) {
|
||||
inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
|
||||
}
|
||||
if (inputs.extra_context.is_object()) {
|
||||
// TODO: do we need to merge, or replacing is fine?
|
||||
for (const auto & [k, v] : inputs.extra_context.items()) {
|
||||
|
|
@ -831,9 +792,6 @@ static std::string apply(
|
|||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp["tools"].is_null()) {
|
||||
inp["tools"] = json::array();
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
|
|
@ -2260,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
|
|||
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
const std::optional<json> tools_override = json();
|
||||
const std::optional<json> additional_context = json {
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
};
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
|
|
@ -2614,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
|
|||
static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// TODO: Reasoning effort
|
||||
json additional_context = {};
|
||||
// Copy `reasoning_content` to `reasoning`
|
||||
auto adjusted_messages = json::array();
|
||||
for (const auto & msg : inputs.messages) {
|
||||
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
|
||||
auto adjusted_message = msg;
|
||||
adjusted_message["reasoning"] = msg.at("reasoning_content");
|
||||
adjusted_message.erase("reasoning_content");
|
||||
adjusted_messages.push_back(adjusted_message);
|
||||
} else {
|
||||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
|
||||
data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto include_grammar = true;
|
||||
|
||||
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the flush token with end token during inference and without generation prompt.
|
||||
if (inputs.is_inference && !inputs.add_generation_prompt) {
|
||||
static constexpr std::string_view return_token = "<|flush|>";
|
||||
static constexpr std::string_view end_token = "<|end|>";
|
||||
if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
|
||||
prompt.replace(pos, return_token.length(), end_token);
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = prompt;
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"<|think|>",
|
||||
"<|content|>",
|
||||
"<|begin|>",
|
||||
"<|end|>",
|
||||
"<|tool_calls|>",
|
||||
"<|tool_call:begin|>",
|
||||
"<|tool_call:end|>",
|
||||
"<|tool_call:name|>",
|
||||
"<|tool_call:args|>",
|
||||
};
|
||||
|
||||
// TODO: Tool calling
|
||||
auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
|
||||
auto lit_think = p.atomic(p.literal("<|think|>"));
|
||||
auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
|
||||
auto lit_content = p.atomic(p.literal("<|content|>"));
|
||||
auto lit_end = p.atomic(p.literal("<|end|>"));
|
||||
auto parser_until_end = p.until("<|end|>");
|
||||
|
||||
// reasoning <- "<|think|>" (!"<|end|>" .)*
|
||||
auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
|
||||
|
||||
// content <- "<|content|>" (!"<|end|>" .)*
|
||||
auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
|
||||
|
||||
// wrap_choice(items) <- item-choice wrapped*
|
||||
// item-choice <- items[0] / ... / items[n]
|
||||
// wrapped <- "<|end|><|begin|>assistant" item-choice
|
||||
auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto choice = p.choice(items);
|
||||
return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice);
|
||||
};
|
||||
|
||||
// wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
|
||||
auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
|
||||
auto seq = p.sequence();
|
||||
for (auto i = 0u; i < items.size(); i++) {
|
||||
if (i == 0) {
|
||||
seq += items[i];
|
||||
continue;
|
||||
}
|
||||
seq += lit_end + lit_assistant_begin + items[i];
|
||||
}
|
||||
return seq;
|
||||
};
|
||||
|
||||
// Response format parser
|
||||
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
|
||||
auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_response_format}),
|
||||
wrap_seq({parser_response_format})
|
||||
});
|
||||
}
|
||||
|
||||
auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
|
||||
auto lit_tool_call_name = p.literal("<|tool_call:name|>");
|
||||
auto lit_tool_call_args = p.literal("<|tool_call:args|>");
|
||||
auto lit_tool_call_end = p.literal("<|tool_call:end|>");
|
||||
|
||||
// Tool call parser
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
auto parser_tool_call = p.choice();
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
const auto & schema = function.at("parameters");
|
||||
|
||||
// tool(name, schema) <- name "<|tool_call:args|>" schema
|
||||
parser_tool_call |= p.rule("tool-" + name,
|
||||
p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
|
||||
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
|
||||
});
|
||||
|
||||
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
|
||||
// tool-calls <- "<|tool_calls|>" tool-call+
|
||||
// tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>"
|
||||
// call-id <- [a-zA-Z0-9_-]+
|
||||
// tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema)
|
||||
auto parser_tool_calls = p.trigger_rule("tool-calls",
|
||||
p.atomic(p.literal("<|tool_calls|>"))
|
||||
+ p.repeat(
|
||||
p.tool_open(
|
||||
lit_tool_call_begin
|
||||
+ p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
|
||||
+ lit_tool_call_name
|
||||
+ p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
|
||||
+ parser_tool_call
|
||||
+ p.tool_close(lit_tool_call_end),
|
||||
/* min = */ 1,
|
||||
/* max = */ max_calls));
|
||||
|
||||
if (min_calls == 1) {
|
||||
// If required, then try any combination of the reasoning, content, and tool call
|
||||
return p.choice({
|
||||
wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_reasoning, parser_tool_calls}),
|
||||
wrap_seq({parser_content, parser_tool_calls}),
|
||||
wrap_seq({parser_tool_calls})
|
||||
});
|
||||
}
|
||||
|
||||
return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
|
||||
}
|
||||
|
||||
// Content only parser
|
||||
include_grammar = false;
|
||||
return wrap_choice({parser_reasoning, parser_content});
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.at("parameters");
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
|
@ -2691,6 +2793,51 @@ static common_chat_params common_chat_params_init_exaone_moe(const common_chat_t
|
|||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// This template does not support tools or reasoning
|
||||
// we just need to transform the messages into the correct schema
|
||||
|
||||
templates_params inputs_new = inputs;
|
||||
json & messages = inputs_new.messages;
|
||||
|
||||
// default to chat_template_kwargs, or en-GB if not specified
|
||||
std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB");
|
||||
std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB");
|
||||
|
||||
GGML_ASSERT(messages.is_array());
|
||||
for (auto & message : messages) {
|
||||
if (message.contains("role") && message["role"].get<std::string>() != "user") {
|
||||
continue;
|
||||
}
|
||||
if (!message.contains("content")) {
|
||||
message["content"] = json::array();
|
||||
}
|
||||
if (message.contains("content") && !message["content"].is_array()) {
|
||||
auto content_str = message["content"].get<std::string>();
|
||||
// default to en-GB if not specified (to make common_chat_format_example works)
|
||||
auto src_lang = message.contains("source_lang_code")
|
||||
? message["source_lang_code"].get<std::string>() : default_src_lang;
|
||||
auto tgt_lang = message.contains("target_lang_code")
|
||||
? message["target_lang_code"].get<std::string>() : default_tgt_lang;
|
||||
message["content"] = json::array({
|
||||
json{
|
||||
{"type", "text"},
|
||||
{"text", content_str},
|
||||
{"source_lang_code", src_lang},
|
||||
{"target_lang_code", tgt_lang},
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt);
|
||||
data.format = COMMON_CHAT_FORMAT_GENERIC;
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
|
|
@ -2867,13 +3014,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
const struct common_chat_templates_inputs & inputs)
|
||||
{
|
||||
templates_params params;
|
||||
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||
params.tools = common_chat_tools_to_json_oaicompat(inputs.tools);
|
||||
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||
? *tmpls->template_tool_use
|
||||
: *tmpls->template_default;
|
||||
const auto & src = tmpl.source();
|
||||
const auto & caps = tmpl.original_caps();
|
||||
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.reasoning_format = inputs.reasoning_format;
|
||||
|
|
@ -2943,6 +3090,10 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
src.find("<arg_value>") != std::string::npos &&
|
||||
params.json_schema.is_null()) {
|
||||
workaround::func_args_not_string(params.messages);
|
||||
if (!params.extra_context.contains("clear_thinking")) {
|
||||
// by default, do not clear reasoning_content (added since GLM-4.7)
|
||||
params.extra_context["clear_thinking"] = false;
|
||||
}
|
||||
return common_chat_params_init_glm_4_5(tmpl, params);
|
||||
}
|
||||
|
||||
|
|
@ -3035,6 +3186,13 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_apriel_1_5(tmpl, params);
|
||||
}
|
||||
|
||||
// Solar Open
|
||||
if (src.find("<|tool_response:begin|>") != std::string::npos &&
|
||||
src.find("<|tool_response:name|>") != std::string::npos &&
|
||||
src.find("<|tool_response:result|>") != std::string::npos) {
|
||||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// Use generic handler when mixing tools + JSON schema.
|
||||
// TODO: support that mix in handlers below.
|
||||
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||
|
|
@ -3082,6 +3240,12 @@ static common_chat_params common_chat_templates_apply_jinja(
|
|||
return common_chat_params_init_solar_open(tmpl, params);
|
||||
}
|
||||
|
||||
// TranslateGemma
|
||||
if (src.find("[source_lang_code]") != std::string::npos &&
|
||||
src.find("[target_lang_code]") != std::string::npos) {
|
||||
return common_chat_params_init_translate_gemma(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
|
|
@ -3174,3 +3338,9 @@ common_chat_params common_chat_templates_apply(
|
|||
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||
}
|
||||
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@
|
|||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
struct common_chat_tool_call {
|
||||
|
|
@ -26,6 +28,11 @@ struct common_chat_msg_content_part {
|
|||
std::string type;
|
||||
std::string text;
|
||||
|
||||
// TODO @ngxson : no known chat templates support reasoning_content in content parts yet
|
||||
// this can be useful for models with interleaved thinking (like Kimi-K2)
|
||||
// if you see any templates explicitly support this, please ping me
|
||||
// std::string reasoning_content;
|
||||
|
||||
bool operator==(const common_chat_msg_content_part & other) const {
|
||||
return type == other.type && text == other.text;
|
||||
}
|
||||
|
|
@ -40,7 +47,7 @@ struct common_chat_msg {
|
|||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
||||
template <class T> T to_json_oaicompat() const;
|
||||
nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const;
|
||||
|
||||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
|
|
@ -145,7 +152,7 @@ struct common_chat_templates_inputs {
|
|||
std::vector<common_chat_tool> tools;
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking"
|
||||
bool enable_thinking = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
std::map<std::string, std::string> chat_template_kwargs;
|
||||
|
|
@ -165,14 +172,21 @@ struct common_chat_params {
|
|||
std::string parser;
|
||||
};
|
||||
|
||||
struct common_chat_syntax {
|
||||
// per-message parsing syntax
|
||||
// should be derived from common_chat_params
|
||||
struct common_chat_parser_params {
|
||||
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning"
|
||||
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
||||
bool reasoning_in_content = false;
|
||||
bool thinking_forced_open = false;
|
||||
bool parse_tool_calls = true;
|
||||
common_peg_arena parser = {};
|
||||
common_chat_parser_params() = default;
|
||||
common_chat_parser_params(const common_chat_params & chat_params) {
|
||||
format = chat_params.format;
|
||||
thinking_forced_open = chat_params.thinking_forced_open;
|
||||
}
|
||||
};
|
||||
|
||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||
|
|
@ -191,7 +205,7 @@ common_chat_templates_ptr common_chat_templates_init(
|
|||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(
|
||||
|
|
@ -213,23 +227,25 @@ std::string common_chat_format_example(
|
|||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char* common_chat_format_name(common_chat_format format);
|
||||
const char* common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_syntax & syntax);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
|
||||
|
||||
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
|
||||
|
||||
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
||||
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
|
||||
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
||||
|
||||
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
||||
// T can be std::string containing JSON or nlohmann::ordered_json
|
||||
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
||||
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);
|
||||
nlohmann::ordered_json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
||||
|
||||
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
// get template caps, useful for reporting to server /props endpoint
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates);
|
||||
|
|
|
|||
|
|
@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) :
|
|||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
|
||||
llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
params.fit_params_target.data(),
|
||||
params.fit_params_min_ctx,
|
||||
params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR);
|
||||
}
|
||||
|
||||
|
|
@ -1208,10 +1211,6 @@ std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
|||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT;
|
|||
extern const char * LLAMA_COMPILER;
|
||||
extern const char * LLAMA_BUILD_TARGET;
|
||||
|
||||
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
|
||||
|
||||
struct common_control_vector_load_info;
|
||||
|
||||
//
|
||||
|
|
@ -162,6 +164,17 @@ enum common_params_sampling_config : uint64_t {
|
|||
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
|
||||
};
|
||||
|
||||
enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
|
||||
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache
|
||||
COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type
|
||||
};
|
||||
|
||||
// sampling parameters
|
||||
struct common_params_sampling {
|
||||
|
|
@ -240,17 +253,40 @@ struct common_params_model {
|
|||
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
|
||||
};
|
||||
|
||||
struct common_params_speculative {
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
struct common_ngram_mod;
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
struct common_params_speculative {
|
||||
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
|
||||
|
||||
// general-purpose speculative decoding parameters
|
||||
|
||||
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
|
||||
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
|
||||
float p_split = 0.1f; // speculative decoding split probability
|
||||
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
|
||||
|
||||
// ngram-based speculative decoding
|
||||
|
||||
uint16_t ngram_size_n = 12; // ngram size for lookup
|
||||
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
|
||||
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
|
||||
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
|
||||
|
||||
std::shared_ptr<common_ngram_mod> ngram_mod;
|
||||
|
||||
std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
|
||||
// draft-model speculative decoding
|
||||
|
||||
struct common_params_model mparams_dft;
|
||||
|
||||
llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts
|
||||
|
||||
llama_context_params cparams_dft; // these are the parameters for the draft llama_context
|
||||
|
||||
int32_t n_ctx = 0; // draft context size
|
||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||
|
||||
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
|
||||
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
|
||||
|
|
@ -258,7 +294,14 @@ struct common_params_speculative {
|
|||
struct cpu_params cpuparams;
|
||||
struct cpu_params cpuparams_batch;
|
||||
|
||||
struct common_params_model model;
|
||||
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> replacements; // main to speculative model replacements
|
||||
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
|
||||
|
||||
bool has_dft() const {
|
||||
return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct common_params_vocoder {
|
||||
|
|
@ -284,6 +327,7 @@ struct common_params_diffusion {
|
|||
};
|
||||
|
||||
// reasoning API response format (not to be confused as chat template's reasoning format)
|
||||
// only used by server
|
||||
enum common_reasoning_format {
|
||||
COMMON_REASONING_FORMAT_NONE,
|
||||
COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content`
|
||||
|
|
@ -375,8 +419,6 @@ struct common_params {
|
|||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
|
||||
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
|
||||
std::string input_suffix = ""; // string to suffix user inputs with // NOLINT
|
||||
std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT
|
||||
std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
|
||||
std::string logits_file = ""; // file for saving *all* logits // NOLINT
|
||||
|
||||
// llama-debug specific options
|
||||
|
|
@ -435,7 +477,7 @@ struct common_params {
|
|||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // enable mmap to use filesystem cache
|
||||
bool use_direct_io = true; // read from disk without buffering for faster model loading
|
||||
bool use_direct_io = false; // read from disk without buffering
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool verbose_prompt = false; // print prompt tokens before generation
|
||||
bool display_prompt = true; // print prompt before generation
|
||||
|
|
@ -572,10 +614,6 @@ struct common_params {
|
|||
// return false from callback to abort model loading or true to continue
|
||||
llama_progress_callback load_progress_callback = NULL;
|
||||
void * load_progress_callback_user_data = NULL;
|
||||
|
||||
bool has_speculative() const {
|
||||
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
|
||||
}
|
||||
};
|
||||
|
||||
// call once at the start of a program if it uses libcommon
|
||||
|
|
@ -711,8 +749,6 @@ struct common_init_result {
|
|||
|
||||
std::vector<llama_adapter_lora_ptr> & lora();
|
||||
|
||||
void free_context();
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
std::unique_ptr<impl> pimpl;
|
||||
|
|
|
|||
|
|
@ -314,23 +314,26 @@ static bool common_pull_file(httplib::Client & cli,
|
|||
|
||||
// download one single file from remote URL to local path
|
||||
// returns status code or -1 on error
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static int common_download_file_single_online(const std::string & url,
|
||||
const std::string & path,
|
||||
const std::string & bearer_token,
|
||||
const common_header_list & custom_headers) {
|
||||
static const int max_attempts = 3;
|
||||
static const int retry_delay_seconds = 2;
|
||||
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
|
||||
if (!bearer_token.empty()) {
|
||||
default_headers.insert({"Authorization", "Bearer " + bearer_token});
|
||||
}
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : custom_headers) {
|
||||
default_headers.emplace(h.first, h.second);
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
cli.set_default_headers(default_headers);
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
if (!bearer_token.empty()) {
|
||||
headers.emplace("Authorization", "Bearer " + bearer_token);
|
||||
}
|
||||
cli.set_default_headers(headers);
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
|
|
@ -437,10 +440,12 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
|
|||
const common_remote_params & params) {
|
||||
auto [cli, parts] = common_http_client(url);
|
||||
|
||||
httplib::Headers headers = {{"User-Agent", "llama-cpp"}};
|
||||
|
||||
for (const auto & header : params.headers) {
|
||||
headers.emplace(header.first, header.second);
|
||||
httplib::Headers headers;
|
||||
for (const auto & h : params.headers) {
|
||||
headers.emplace(h.first, h.second);
|
||||
}
|
||||
if (headers.find("User-Agent") == headers.end()) {
|
||||
headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
||||
}
|
||||
|
||||
if (params.timeout > 0) {
|
||||
|
|
|
|||
|
|
@ -57,6 +57,17 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
|||
throw std::runtime_error("error: invalid URL format");
|
||||
}
|
||||
|
||||
#ifndef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||
if (parts.scheme == "https") {
|
||||
throw std::runtime_error(
|
||||
"HTTPS is not supported. Please rebuild with one of:\n"
|
||||
" -DLLAMA_BUILD_BORINGSSL=ON\n"
|
||||
" -DLLAMA_BUILD_LIBRESSL=ON\n"
|
||||
" -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)"
|
||||
);
|
||||
}
|
||||
#endif
|
||||
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host);
|
||||
|
||||
if (!parts.user.empty()) {
|
||||
|
|
|
|||
|
|
@ -61,14 +61,23 @@ static void caps_print_stats(value & v, const std::string & path) {
|
|||
ops.c_str());
|
||||
}
|
||||
|
||||
std::map<std::string, bool> caps::to_map() const {
|
||||
return {
|
||||
{"requires_typed_content", requires_typed_content},
|
||||
{"supports_tools", supports_tools},
|
||||
{"supports_tool_calls", supports_tool_calls},
|
||||
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
|
||||
{"supports_system_role", supports_system_role},
|
||||
{"supports_preserve_reasoning", supports_preserve_reasoning},
|
||||
};
|
||||
}
|
||||
|
||||
std::string caps::to_string() const {
|
||||
std::ostringstream ss;
|
||||
ss << "Caps(\n";
|
||||
ss << " requires_typed_content=" << requires_typed_content << "\n";
|
||||
ss << " supports_tools=" << supports_tools << "\n";
|
||||
ss << " supports_tool_calls=" << supports_tool_calls << "\n";
|
||||
ss << " supports_parallel_tool_calls=" << supports_parallel_tool_calls << "\n";
|
||||
ss << " supports_system_role=" << supports_system_role << "\n";
|
||||
for (const auto & [key, value] : to_map()) {
|
||||
ss << " " << key << "=" << (value ? "true" : "false") << "\n";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
|
@ -229,6 +238,40 @@ caps caps_get(jinja::program & prog) {
|
|||
}
|
||||
);
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
{"reasoning_content", "Reasoning content"}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
JJ_DEBUG("%s\n", result.to_string().c_str());
|
||||
|
||||
return result;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include "runtime.h"
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
|
|
@ -11,14 +12,17 @@ struct caps {
|
|||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
bool requires_typed_content = false; // default: use string content
|
||||
|
||||
// for reporting on server
|
||||
std::map<std::string, bool> to_map() const;
|
||||
|
||||
// for debugging
|
||||
std::string to_string() const;
|
||||
};
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
void debug_print_caps(const caps & c);
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -91,6 +91,16 @@ lexer_result lexer::tokenize(const std::string & source) {
|
|||
return str;
|
||||
};
|
||||
|
||||
auto consume_numeric = [&]() -> std::string {
|
||||
std::string num = consume_while(is_integer);
|
||||
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
|
||||
++pos; // Consume '.'
|
||||
std::string frac = consume_while(is_integer);
|
||||
num += "." + frac;
|
||||
}
|
||||
return num;
|
||||
};
|
||||
|
||||
auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
|
||||
if (pos + n >= src.size()) return false;
|
||||
for (char c : chars) {
|
||||
|
|
@ -258,7 +268,7 @@ lexer_result lexer::tokenize(const std::string & source) {
|
|||
++pos; // Consume the operator
|
||||
|
||||
// Check for numbers following the unary operator
|
||||
std::string num = consume_while(is_integer);
|
||||
std::string num = consume_numeric();
|
||||
std::string value = std::string(1, ch) + num;
|
||||
token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
|
||||
// JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
|
||||
|
|
@ -307,12 +317,7 @@ lexer_result lexer::tokenize(const std::string & source) {
|
|||
// Numbers
|
||||
if (is_integer(ch)) {
|
||||
start_pos = pos;
|
||||
std::string num = consume_while(is_integer);
|
||||
if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
|
||||
++pos; // Consume '.'
|
||||
std::string frac = consume_while(is_integer);
|
||||
num += "." + frac;
|
||||
}
|
||||
std::string num = consume_numeric();
|
||||
// JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
|
||||
tokens.push_back({token::numeric_literal, num, start_pos});
|
||||
continue;
|
||||
|
|
|
|||
|
|
@ -44,6 +44,12 @@ static std::string get_line_col(const std::string & source, size_t pos) {
|
|||
return "line " + std::to_string(line) + ", column " + std::to_string(col);
|
||||
}
|
||||
|
||||
static void ensure_key_type_allowed(const value & val) {
|
||||
if (!val->is_hashable()) {
|
||||
throw std::runtime_error("Type: " + val->type() + " is not allowed as object key");
|
||||
}
|
||||
}
|
||||
|
||||
// execute with error handling
|
||||
value statement::execute(context & ctx) {
|
||||
try {
|
||||
|
|
@ -95,20 +101,10 @@ value identifier::execute_impl(context & ctx) {
|
|||
value object_literal::execute_impl(context & ctx) {
|
||||
auto obj = mk_val<value_object>();
|
||||
for (const auto & pair : val) {
|
||||
value key_val = pair.first->execute(ctx);
|
||||
if (!is_val<value_string>(key_val) && !is_val<value_int>(key_val)) {
|
||||
throw std::runtime_error("Object literal: keys must be string or int values, got " + key_val->type());
|
||||
}
|
||||
std::string key = key_val->as_string().str();
|
||||
value key = pair.first->execute(ctx);
|
||||
value val = pair.second->execute(ctx);
|
||||
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key.c_str(), val->type().c_str());
|
||||
JJ_DEBUG("Object literal: setting key '%s' with value type %s", key->as_string().str().c_str(), val->type().c_str());
|
||||
obj->insert(key, val);
|
||||
|
||||
if (is_val<value_int>(key_val)) {
|
||||
obj->val_obj.is_key_numeric = true;
|
||||
} else if (obj->val_obj.is_key_numeric) {
|
||||
throw std::runtime_error("Object literal: cannot mix numeric and non-numeric keys");
|
||||
}
|
||||
}
|
||||
return obj;
|
||||
}
|
||||
|
|
@ -127,9 +123,9 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
value right_val = right->execute(ctx);
|
||||
JJ_DEBUG("Executing binary expression %s '%s' %s", left_val->type().c_str(), op.value.c_str(), right_val->type().c_str());
|
||||
if (op.value == "==") {
|
||||
return mk_val<value_bool>(value_compare(left_val, right_val, value_compare_op::eq));
|
||||
return mk_val<value_bool>(*left_val == *right_val);
|
||||
} else if (op.value == "!=") {
|
||||
return mk_val<value_bool>(!value_compare(left_val, right_val, value_compare_op::eq));
|
||||
return mk_val<value_bool>(!(*left_val == *right_val));
|
||||
}
|
||||
|
||||
auto workaround_concat_null_with_str = [&](value & res) -> bool {
|
||||
|
|
@ -230,7 +226,7 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
auto & arr = right_val->as_array();
|
||||
bool member = false;
|
||||
for (const auto & item : arr) {
|
||||
if (value_compare(left_val, item, value_compare_op::eq)) {
|
||||
if (*left_val == *item) {
|
||||
member = true;
|
||||
break;
|
||||
}
|
||||
|
|
@ -265,11 +261,9 @@ value binary_expression::execute_impl(context & ctx) {
|
|||
}
|
||||
}
|
||||
|
||||
// String in object
|
||||
if (is_val<value_string>(left_val) && is_val<value_object>(right_val)) {
|
||||
auto key = left_val->as_string().str();
|
||||
auto & obj = right_val->as_object();
|
||||
bool has_key = obj.find(key) != obj.end();
|
||||
// Value key in object
|
||||
if (is_val<value_object>(right_val)) {
|
||||
bool has_key = right_val->has_key(left_val);
|
||||
if (op.value == "in") {
|
||||
return mk_val<value_bool>(has_key);
|
||||
} else if (op.value == "not in") {
|
||||
|
|
@ -464,16 +458,10 @@ value for_statement::execute_impl(context & ctx) {
|
|||
std::vector<value> items;
|
||||
if (is_val<value_object>(iterable_val)) {
|
||||
JJ_DEBUG("%s", "For loop over object keys");
|
||||
auto & obj = iterable_val->as_object();
|
||||
auto & obj = iterable_val->as_ordered_object();
|
||||
for (auto & p : obj) {
|
||||
auto tuple = mk_val<value_array>();
|
||||
if (iterable_val->val_obj.is_key_numeric) {
|
||||
tuple->push_back(mk_val<value_int>(std::stoll(p.first)));
|
||||
} else {
|
||||
tuple->push_back(mk_val<value_string>(p.first));
|
||||
}
|
||||
tuple->push_back(p.second);
|
||||
items.push_back(tuple);
|
||||
auto tuple = mk_val<value_tuple>(p);
|
||||
items.push_back(std::move(tuple));
|
||||
}
|
||||
if (ctx.is_get_stats) {
|
||||
iterable_val->stats.used = true;
|
||||
|
|
@ -560,6 +548,7 @@ value for_statement::execute_impl(context & ctx) {
|
|||
for (size_t i = 0; i < filtered_items.size(); i++) {
|
||||
JJ_DEBUG("For loop iteration %zu/%zu", i + 1, filtered_items.size());
|
||||
value_object loop_obj = mk_val<value_object>();
|
||||
loop_obj->has_builtins = false; // loop object has no builtins
|
||||
loop_obj->insert("index", mk_val<value_int>(i + 1));
|
||||
loop_obj->insert("index0", mk_val<value_int>(i));
|
||||
loop_obj->insert("revindex", mk_val<value_int>(filtered_items.size() - i));
|
||||
|
|
@ -602,11 +591,13 @@ value set_statement::execute_impl(context & ctx) {
|
|||
auto rhs = val ? val->execute(ctx) : exec_statements(body, ctx);
|
||||
|
||||
if (is_stmt<identifier>(assignee)) {
|
||||
// case: {% set my_var = value %}
|
||||
auto var_name = cast_stmt<identifier>(assignee)->val;
|
||||
JJ_DEBUG("Setting global variable '%s' with value type %s", var_name.c_str(), rhs->type().c_str());
|
||||
ctx.set_val(var_name, rhs);
|
||||
|
||||
} else if (is_stmt<tuple_literal>(assignee)) {
|
||||
// case: {% set a, b = value %}
|
||||
auto tuple = cast_stmt<tuple_literal>(assignee);
|
||||
if (!is_val<value_array>(rhs)) {
|
||||
throw std::runtime_error("Cannot unpack non-iterable type in set: " + rhs->type());
|
||||
|
|
@ -625,6 +616,7 @@ value set_statement::execute_impl(context & ctx) {
|
|||
}
|
||||
|
||||
} else if (is_stmt<member_expression>(assignee)) {
|
||||
// case: {% set ns.my_var = value %}
|
||||
auto member = cast_stmt<member_expression>(assignee);
|
||||
if (member->computed) {
|
||||
throw std::runtime_error("Cannot assign to computed member");
|
||||
|
|
@ -717,6 +709,7 @@ value member_expression::execute_impl(context & ctx) {
|
|||
|
||||
value property;
|
||||
if (this->computed) {
|
||||
// syntax: obj[expr]
|
||||
JJ_DEBUG("Member expression, computing property type %s", this->property->type().c_str());
|
||||
|
||||
int64_t arr_size = 0;
|
||||
|
|
@ -745,32 +738,43 @@ value member_expression::execute_impl(context & ctx) {
|
|||
property = this->property->execute(ctx);
|
||||
}
|
||||
} else {
|
||||
// syntax: obj.prop
|
||||
if (!is_stmt<identifier>(this->property)) {
|
||||
throw std::runtime_error("Non-computed member property must be an identifier");
|
||||
throw std::runtime_error("Static member property must be an identifier");
|
||||
}
|
||||
property = mk_val<value_string>(cast_stmt<identifier>(this->property)->val);
|
||||
std::string prop = property->as_string().str();
|
||||
JJ_DEBUG("Member expression, object type %s, static property '%s'", object->type().c_str(), prop.c_str());
|
||||
|
||||
// behavior of jinja2: obj having prop as a built-in function AND 'prop', as an object key,
|
||||
// then obj.prop returns the built-in function, not the property value.
|
||||
// while obj['prop'] returns the property value.
|
||||
// example: {"obj": {"items": 123}} -> obj.items is the built-in function, obj['items'] is 123
|
||||
|
||||
value val = try_builtin_func(ctx, prop, object, true);
|
||||
if (!is_val<value_undefined>(val)) {
|
||||
return val;
|
||||
}
|
||||
// else, fallthrough to normal property access below
|
||||
}
|
||||
|
||||
JJ_DEBUG("Member expression on object type %s, property type %s", object->type().c_str(), property->type().c_str());
|
||||
ensure_key_type_allowed(property);
|
||||
|
||||
value val = mk_val<value_undefined>("object_property");
|
||||
|
||||
if (is_val<value_undefined>(object)) {
|
||||
JJ_DEBUG("%s", "Accessing property on undefined object, returning undefined");
|
||||
return val;
|
||||
|
||||
} else if (is_val<value_object>(object)) {
|
||||
if (!is_val<value_string>(property)) {
|
||||
throw std::runtime_error("Cannot access object with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
auto & obj = object->as_object();
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.end()) {
|
||||
val = it->second;
|
||||
} else {
|
||||
val = object->at(property, val);
|
||||
if (is_val<value_undefined>(val)) {
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
JJ_DEBUG("Accessed property '%s' value, got type: %s", key.c_str(), val->type().c_str());
|
||||
|
||||
} else if (is_val<value_array>(object) || is_val<value_string>(object)) {
|
||||
if (is_val<value_int>(property)) {
|
||||
int64_t index = property->as_int();
|
||||
|
|
@ -793,7 +797,8 @@ value member_expression::execute_impl(context & ctx) {
|
|||
} else if (is_val<value_string>(property)) {
|
||||
auto key = property->as_string().str();
|
||||
JJ_DEBUG("Accessing %s built-in '%s'", is_val<value_array>(object) ? "array" : "string", key.c_str());
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
|
||||
} else {
|
||||
throw std::runtime_error("Cannot access property with non-string/non-number: got " + property->type());
|
||||
}
|
||||
|
|
@ -802,7 +807,7 @@ value member_expression::execute_impl(context & ctx) {
|
|||
throw std::runtime_error("Cannot access property with non-string: got " + property->type());
|
||||
}
|
||||
auto key = property->as_string().str();
|
||||
val = try_builtin_func(ctx, key, object);
|
||||
val = try_builtin_func(ctx, key, object, true);
|
||||
}
|
||||
|
||||
if (ctx.is_get_stats && val && object && property) {
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ struct context {
|
|||
// src is optional, used for error reporting
|
||||
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
|
||||
env = mk_val<value_object>();
|
||||
env->has_builtins = false; // context object has no builtins
|
||||
env->insert("true", mk_val<value_bool>(true));
|
||||
env->insert("True", mk_val<value_bool>(true));
|
||||
env->insert("false", mk_val<value_bool>(false));
|
||||
|
|
@ -68,7 +69,7 @@ struct context {
|
|||
|
||||
context(const context & parent) : context() {
|
||||
// inherit variables (for example, when entering a new scope)
|
||||
auto & pvar = parent.env->as_object();
|
||||
auto & pvar = parent.env->as_ordered_object();
|
||||
for (const auto & pair : pvar) {
|
||||
set_val(pair.first, pair.second);
|
||||
}
|
||||
|
|
@ -78,18 +79,18 @@ struct context {
|
|||
}
|
||||
|
||||
value get_val(const std::string & name) {
|
||||
auto it = env->val_obj.unordered.find(name);
|
||||
if (it != env->val_obj.unordered.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return mk_val<value_undefined>(name);
|
||||
}
|
||||
value default_val = mk_val<value_undefined>(name);
|
||||
return env->at(name, default_val);
|
||||
}
|
||||
|
||||
void set_val(const std::string & name, const value & val) {
|
||||
env->insert(name, val);
|
||||
}
|
||||
|
||||
void set_val(const value & name, const value & val) {
|
||||
env->insert(name, val);
|
||||
}
|
||||
|
||||
void print_vars() const {
|
||||
printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
|
||||
}
|
||||
|
|
@ -265,7 +266,7 @@ struct comment_statement : public statement {
|
|||
struct member_expression : public expression {
|
||||
statement_ptr object;
|
||||
statement_ptr property;
|
||||
bool computed;
|
||||
bool computed; // true if obj[expr] and false if obj.prop
|
||||
|
||||
member_expression(statement_ptr && object, statement_ptr && property, bool computed)
|
||||
: object(std::move(object)), property(std::move(property)), computed(computed) {
|
||||
|
|
@ -343,9 +344,19 @@ struct array_literal : public expression {
|
|||
}
|
||||
};
|
||||
|
||||
struct tuple_literal : public array_literal {
|
||||
explicit tuple_literal(statements && val) : array_literal(std::move(val)) {}
|
||||
struct tuple_literal : public expression {
|
||||
statements val;
|
||||
explicit tuple_literal(statements && val) : val(std::move(val)) {
|
||||
for (const auto& item : this->val) chk_type<expression>(item);
|
||||
}
|
||||
std::string type() const override { return "TupleLiteral"; }
|
||||
value execute_impl(context & ctx) override {
|
||||
auto arr = mk_val<value_array>();
|
||||
for (const auto & item_stmt : val) {
|
||||
arr->push_back(item_stmt->execute(ctx));
|
||||
}
|
||||
return mk_val<value_tuple>(std::move(arr->as_array()));
|
||||
}
|
||||
};
|
||||
|
||||
struct object_literal : public expression {
|
||||
|
|
|
|||
|
|
@ -61,6 +61,12 @@ size_t string::length() const {
|
|||
return len;
|
||||
}
|
||||
|
||||
void string::hash_update(hasher & hash) const noexcept {
|
||||
for (const auto & part : parts) {
|
||||
hash.update(part.val.data(), part.val.length());
|
||||
}
|
||||
}
|
||||
|
||||
bool string::all_parts_are_input() const {
|
||||
for (const auto & part : parts) {
|
||||
if (!part.is_input) {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
namespace jinja {
|
||||
|
||||
// allow differentiate between user input strings and template strings
|
||||
|
|
@ -37,6 +39,7 @@ struct string {
|
|||
|
||||
std::string str() const;
|
||||
size_t length() const;
|
||||
void hash_update(hasher & hash) const noexcept;
|
||||
bool all_parts_are_input() const;
|
||||
bool is_uppercase() const;
|
||||
bool is_lowercase() const;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
#include <string>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
namespace jinja {
|
||||
|
||||
|
|
@ -46,4 +48,102 @@ static std::string fmt_error_with_source(const std::string & tag, const std::str
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
// Note: this is a simple hasher, not cryptographically secure, just for hash table usage
|
||||
struct hasher {
|
||||
static constexpr auto size_t_digits = sizeof(size_t) * 8;
|
||||
static constexpr size_t prime = size_t_digits == 64 ? 0x100000001b3 : 0x01000193;
|
||||
static constexpr size_t seed = size_t_digits == 64 ? 0xcbf29ce484222325 : 0x811c9dc5;
|
||||
static constexpr auto block_size = sizeof(size_t); // in bytes; allowing the compiler to vectorize the computation
|
||||
|
||||
static_assert(size_t_digits == 64 || size_t_digits == 32);
|
||||
static_assert(block_size == 8 || block_size == 4);
|
||||
|
||||
uint8_t buffer[block_size];
|
||||
size_t idx = 0; // current index in buffer
|
||||
size_t state = seed;
|
||||
|
||||
hasher() = default;
|
||||
hasher(const std::type_info & type_inf) noexcept {
|
||||
const auto type_hash = type_inf.hash_code();
|
||||
update(&type_hash, sizeof(type_hash));
|
||||
}
|
||||
|
||||
// Properties:
|
||||
// - update is not associative: update(a).update(b) != update(b).update(a)
|
||||
// - update(a ~ b) == update(a).update(b) with ~ as concatenation operator --> useful for streaming
|
||||
// - update("", 0) --> state unchanged with empty input
|
||||
hasher& update(void const * bytes, size_t len) noexcept {
|
||||
const uint8_t * c = static_cast<uint8_t const *>(bytes);
|
||||
if (len == 0) {
|
||||
return *this;
|
||||
}
|
||||
size_t processed = 0;
|
||||
|
||||
// first, fill the existing buffer if it's partial
|
||||
if (idx > 0) {
|
||||
size_t to_fill = block_size - idx;
|
||||
if (to_fill > len) {
|
||||
to_fill = len;
|
||||
}
|
||||
std::memcpy(buffer + idx, c, to_fill);
|
||||
idx += to_fill;
|
||||
processed += to_fill;
|
||||
if (idx == block_size) {
|
||||
update_block(buffer);
|
||||
idx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// process full blocks from the remaining input
|
||||
for (; processed + block_size <= len; processed += block_size) {
|
||||
update_block(c + processed);
|
||||
}
|
||||
|
||||
// buffer any remaining bytes
|
||||
size_t remaining = len - processed;
|
||||
if (remaining > 0) {
|
||||
std::memcpy(buffer, c + processed, remaining);
|
||||
idx = remaining;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
// convenience function for testing only
|
||||
hasher& update(const std::string & s) noexcept {
|
||||
return update(s.data(), s.size());
|
||||
}
|
||||
|
||||
// finalize and get the hash value
|
||||
// note: after calling digest, the hasher state is modified, do not call update() again
|
||||
size_t digest() noexcept {
|
||||
// if there are remaining bytes in buffer, fill the rest with zeros and process
|
||||
if (idx > 0) {
|
||||
for (size_t i = idx; i < block_size; ++i) {
|
||||
buffer[i] = 0;
|
||||
}
|
||||
update_block(buffer);
|
||||
idx = 0;
|
||||
}
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
private:
|
||||
// IMPORTANT: block must have at least block_size bytes
|
||||
void update_block(const uint8_t * block) noexcept {
|
||||
size_t blk = static_cast<uint32_t>(block[0])
|
||||
| (static_cast<uint32_t>(block[1]) << 8)
|
||||
| (static_cast<uint32_t>(block[2]) << 16)
|
||||
| (static_cast<uint32_t>(block[3]) << 24);
|
||||
if constexpr (block_size == 8) {
|
||||
blk = blk | (static_cast<uint64_t>(block[4]) << 32)
|
||||
| (static_cast<uint64_t>(block[5]) << 40)
|
||||
| (static_cast<uint64_t>(block[6]) << 48)
|
||||
| (static_cast<uint64_t>(block[7]) << 56);
|
||||
}
|
||||
state ^= blk;
|
||||
state *= prime;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -114,6 +114,18 @@ static T slice(const T & array, int64_t start, int64_t stop, int64_t step = 1) {
|
|||
return result;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static value empty_value_fn(const func_args &) {
|
||||
if constexpr (std::is_same_v<T, value_int>) {
|
||||
return mk_val<T>(0);
|
||||
} else if constexpr (std::is_same_v<T, value_float>) {
|
||||
return mk_val<T>(0.0);
|
||||
} else if constexpr (std::is_same_v<T, value_bool>) {
|
||||
return mk_val<T>(false);
|
||||
} else {
|
||||
return mk_val<T>();
|
||||
}
|
||||
}
|
||||
template<typename T>
|
||||
static value test_type_fn(const func_args & args) {
|
||||
args.ensure_count(1);
|
||||
|
|
@ -128,6 +140,13 @@ static value test_type_fn(const func_args & args) {
|
|||
JJ_DEBUG("test_type_fn: type=%s or %s result=%d", typeid(T).name(), typeid(U).name(), is_type ? 1 : 0);
|
||||
return mk_val<value_bool>(is_type);
|
||||
}
|
||||
template<typename T, typename U, typename V>
|
||||
static value test_type_fn(const func_args & args) {
|
||||
args.ensure_count(1);
|
||||
bool is_type = is_val<T>(args.get_pos(0)) || is_val<U>(args.get_pos(0)) || is_val<V>(args.get_pos(0));
|
||||
JJ_DEBUG("test_type_fn: type=%s, %s or %s result=%d", typeid(T).name(), typeid(U).name(), typeid(V).name(), is_type ? 1 : 0);
|
||||
return mk_val<value_bool>(is_type);
|
||||
}
|
||||
template<value_compare_op op>
|
||||
static value test_compare_fn(const func_args & args) {
|
||||
args.ensure_count(2, 2);
|
||||
|
|
@ -163,7 +182,7 @@ static value selectattr(const func_args & args) {
|
|||
args.ensure_vals<value_array, value_string, value_string, value_string>(true, true, false, false);
|
||||
|
||||
auto arr = args.get_pos(0)->as_array();
|
||||
auto attr_name = args.get_pos(1)->as_string().str();
|
||||
auto attribute = args.get_pos(1);
|
||||
auto out = mk_val<value_array>();
|
||||
value val_default = mk_val<value_undefined>();
|
||||
|
||||
|
|
@ -173,7 +192,7 @@ static value selectattr(const func_args & args) {
|
|||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("selectattr: item is not an object");
|
||||
}
|
||||
value attr_val = item->at(attr_name, val_default);
|
||||
value attr_val = item->at(attribute, val_default);
|
||||
bool is_selected = attr_val->as_bool();
|
||||
if constexpr (is_reject) is_selected = !is_selected;
|
||||
if (is_selected) out->push_back(item);
|
||||
|
|
@ -217,7 +236,7 @@ static value selectattr(const func_args & args) {
|
|||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("selectattr: item is not an object");
|
||||
}
|
||||
value attr_val = item->at(attr_name, val_default);
|
||||
value attr_val = item->at(attribute, val_default);
|
||||
func_args test_args(args.ctx);
|
||||
test_args.push_back(attr_val); // attribute value
|
||||
test_args.push_back(extra_arg); // extra argument
|
||||
|
|
@ -347,8 +366,8 @@ const func_builtins & global_builtins() {
|
|||
{"test_is_integer", test_type_fn<value_int>},
|
||||
{"test_is_float", test_type_fn<value_float>},
|
||||
{"test_is_number", test_type_fn<value_int, value_float>},
|
||||
{"test_is_iterable", test_type_fn<value_array, value_string>},
|
||||
{"test_is_sequence", test_type_fn<value_array, value_string>},
|
||||
{"test_is_iterable", test_type_fn<value_array, value_string, value_undefined>},
|
||||
{"test_is_sequence", test_type_fn<value_array, value_string, value_undefined>},
|
||||
{"test_is_mapping", test_type_fn<value_object>},
|
||||
{"test_is_lower", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_string>();
|
||||
|
|
@ -698,6 +717,7 @@ const func_builtins & value_bool_t::get_builtins() const {
|
|||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_string>(val ? "True" : "False");
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -740,6 +760,7 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array, value_int, value_int, value_int>(true, true, false, false);
|
||||
|
||||
auto val = args.get_pos(0);
|
||||
auto arg0 = args.get_pos(1);
|
||||
auto arg1 = args.get_pos(2, mk_val<value_undefined>());
|
||||
auto arg2 = args.get_pos(3, mk_val<value_undefined>());
|
||||
|
|
@ -761,10 +782,8 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
if (step == 0) {
|
||||
throw raised_exception("slice step cannot be zero");
|
||||
}
|
||||
auto arr = slice(args.get_pos(0)->as_array(), start, stop, step);
|
||||
auto res = mk_val<value_array>();
|
||||
res->val_arr = std::move(arr);
|
||||
return res;
|
||||
auto arr = slice(val->as_array(), start, stop, step);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"selectattr", selectattr<false>},
|
||||
{"select", selectattr<false>},
|
||||
|
|
@ -775,19 +794,29 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("join() first argument must be an array");
|
||||
}
|
||||
value val_delim = args.get_kwarg_or_pos("d", 1);
|
||||
value val_attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!val_attribute->is_undefined()) {
|
||||
throw not_implemented_exception("array attribute join not implemented");
|
||||
}
|
||||
value val_delim = args.get_kwarg_or_pos("d", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
|
||||
throw raised_exception("join() attribute must be string or integer");
|
||||
}
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
|
||||
std::string result;
|
||||
for (size_t i = 0; i < arr.size(); ++i) {
|
||||
if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
|
||||
value val_arr = arr[i];
|
||||
if (!attribute->is_undefined()) {
|
||||
if (attr_is_int && is_val<value_array>(val_arr)) {
|
||||
val_arr = val_arr->at(attr_int);
|
||||
} else if (!attr_is_int && is_val<value_object>(val_arr)) {
|
||||
val_arr = val_arr->at(attribute);
|
||||
}
|
||||
}
|
||||
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
|
||||
throw raised_exception("join() can only join arrays of strings or numerics");
|
||||
}
|
||||
result += arr[i]->as_string().str();
|
||||
result += val_arr->as_string().str();
|
||||
if (i < arr.size() - 1) {
|
||||
result += delim;
|
||||
}
|
||||
|
|
@ -796,35 +825,37 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
auto str = mk_val<value_string>();
|
||||
gather_string_parts_recursive(args.get_pos(0), str);
|
||||
return str;
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"map", [](const func_args & args) -> value {
|
||||
args.ensure_count(2, 3);
|
||||
args.ensure_count(2);
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("map: first argument must be an array");
|
||||
}
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 1);
|
||||
if (is_val<value_int>(attribute)) {
|
||||
throw not_implemented_exception("map: integer attribute not implemented");
|
||||
if (!is_val<value_kwarg>(args.get_args().at(1))) {
|
||||
throw not_implemented_exception("map: filter-mapping not implemented");
|
||||
}
|
||||
if (!is_val<value_string>(attribute)) {
|
||||
value val = args.get_pos(0);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 1);
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
if (!is_val<value_string>(attribute) && !attr_is_int) {
|
||||
throw raised_exception("map: attribute must be string or integer");
|
||||
}
|
||||
std::string attr_name = attribute->as_string().str();
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
|
||||
auto out = mk_val<value_array>();
|
||||
auto arr = args.get_pos(0)->as_array();
|
||||
auto arr = val->as_array();
|
||||
for (const auto & item : arr) {
|
||||
if (!is_val<value_object>(item)) {
|
||||
throw raised_exception("map: item is not an object");
|
||||
value attr_val;
|
||||
if (attr_is_int) {
|
||||
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
|
||||
} else {
|
||||
attr_val = is_val<value_object>(item) ? item->at(attribute, default_val) : default_val;
|
||||
}
|
||||
value attr_val = item->at(attr_name, default_val);
|
||||
out->push_back(attr_val);
|
||||
}
|
||||
return out;
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(out->as_array())) : out;
|
||||
}},
|
||||
{"append", [](const func_args & args) -> value {
|
||||
args.ensure_count(2);
|
||||
|
|
@ -847,37 +878,44 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
return arr_editable->pop_at(index);
|
||||
}},
|
||||
{"sort", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 3);
|
||||
args.ensure_count(1, 4);
|
||||
if (!is_val<value_array>(args.get_pos(0))) {
|
||||
throw raised_exception("sort: first argument must be an array");
|
||||
}
|
||||
bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
|
||||
value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
|
||||
std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
|
||||
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
|
||||
value val = args.get_pos(0);
|
||||
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 3);
|
||||
// FIXME: sorting is currently always case sensitive
|
||||
//const bool case_sensitive = val_case->as_bool(); // undefined == false
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
const bool attr_is_int = is_val<value_int>(attribute);
|
||||
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
|
||||
std::vector<value> arr = val->as_array(); // copy
|
||||
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
|
||||
value val_a = a;
|
||||
value val_b = b;
|
||||
if (!attribute->is_undefined()) {
|
||||
if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
|
||||
throw raised_exception("sort: items are not objects");
|
||||
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
|
||||
val_a = a->at(attr_int);
|
||||
val_b = b->at(attr_int);
|
||||
} else if (!attr_is_int && is_val<value_object>(a) && is_val<value_object>(b)) {
|
||||
val_a = a->at(attribute);
|
||||
val_b = b->at(attribute);
|
||||
} else {
|
||||
throw raised_exception("sort: unsupported object attribute comparison between " + a->type() + " and " + b->type());
|
||||
}
|
||||
val_a = attr.empty() ? a : a->at(attr);
|
||||
val_b = attr.empty() ? b : b->at(attr);
|
||||
}
|
||||
if (reverse) {
|
||||
return value_compare(val_a, val_b, value_compare_op::gt);
|
||||
} else {
|
||||
return !value_compare(val_a, val_b, value_compare_op::gt);
|
||||
}
|
||||
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
});
|
||||
return mk_val<value_array>(arr);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"reverse", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_array>();
|
||||
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
|
||||
value val = args.get_pos(0);
|
||||
std::vector<value> arr = val->as_array(); // copy
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return mk_val<value_array>(arr);
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"unique", [](const func_args &) -> value {
|
||||
throw not_implemented_exception("Array unique builtin not implemented");
|
||||
|
|
@ -888,6 +926,11 @@ const func_builtins & value_array_t::get_builtins() const {
|
|||
|
||||
|
||||
const func_builtins & value_object_t::get_builtins() const {
|
||||
if (!has_builtins) {
|
||||
static const func_builtins no_builtins = {};
|
||||
return no_builtins;
|
||||
}
|
||||
|
||||
static const func_builtins builtins = {
|
||||
// {"default", default_value}, // cause issue with gpt-oss
|
||||
{"get", [](const func_args & args) -> value {
|
||||
|
|
@ -902,27 +945,22 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
if (args.count() == 3) {
|
||||
default_val = args.get_pos(2);
|
||||
}
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
std::string key = args.get_pos(1)->as_string().str();
|
||||
auto it = obj.find(key);
|
||||
if (it != obj.end()) {
|
||||
return it->second;
|
||||
} else {
|
||||
return default_val;
|
||||
}
|
||||
const value obj = args.get_pos(0);
|
||||
const value key = args.get_pos(1);
|
||||
return obj->at(key, default_val);
|
||||
}},
|
||||
{"keys", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
result->push_back(mk_val<value_string>(pair.first));
|
||||
result->push_back(pair.first);
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"values", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
result->push_back(pair.second);
|
||||
|
|
@ -931,21 +969,22 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
}},
|
||||
{"items", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
auto result = mk_val<value_array>();
|
||||
for (const auto & pair : obj) {
|
||||
auto item = mk_val<value_array>();
|
||||
item->push_back(mk_val<value_string>(pair.first));
|
||||
item->push_back(pair.second);
|
||||
auto item = mk_val<value_tuple>(pair);
|
||||
result->push_back(std::move(item));
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"string", tojson},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
return mk_val<value_string>(args.get_pos(0)->as_string());
|
||||
}},
|
||||
{"length", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_object>();
|
||||
const auto & obj = args.get_pos(0)->as_object();
|
||||
const auto & obj = args.get_pos(0)->as_ordered_object();
|
||||
return mk_val<value_int>(static_cast<int64_t>(obj.size()));
|
||||
}},
|
||||
{"tojson", [](const func_args & args) -> value {
|
||||
|
|
@ -958,21 +997,18 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value val_by = args.get_kwarg_or_pos("by", 2);
|
||||
value val_reverse = args.get_kwarg_or_pos("reverse", 3);
|
||||
// FIXME: sorting is case sensitive
|
||||
// FIXME: sorting is currently always case sensitive
|
||||
//const bool case_sensitive = val_case->as_bool(); // undefined == false
|
||||
const bool reverse = val_reverse->as_bool(); // undefined == false
|
||||
if (!val_by->is_undefined()) {
|
||||
throw not_implemented_exception("dictsort by key not implemented");
|
||||
}
|
||||
if (reverse) {
|
||||
throw not_implemented_exception("dictsort reverse not implemented");
|
||||
}
|
||||
value_t::map obj = val_input->val_obj; // copy
|
||||
std::sort(obj.ordered.begin(), obj.ordered.end(), [&](const auto & a, const auto & b) {
|
||||
return a.first < b.first;
|
||||
const bool by_value = is_val<value_string>(val_by) && val_by->as_string().str() == "value" ? true : false;
|
||||
auto result = mk_val<value_object>(val_input); // copy
|
||||
std::sort(result->val_obj.begin(), result->val_obj.end(), [&](const auto & a, const auto & b) {
|
||||
if (by_value) {
|
||||
return value_compare(a.second, b.second, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
} else {
|
||||
return value_compare(a.first, b.first, reverse ? value_compare_op::gt : value_compare_op::lt);
|
||||
}
|
||||
});
|
||||
auto result = mk_val<value_object>();
|
||||
result->val_obj = std::move(obj);
|
||||
return result;
|
||||
}},
|
||||
{"join", [](const func_args &) -> value {
|
||||
|
|
@ -986,6 +1022,22 @@ const func_builtins & value_none_t::get_builtins() const {
|
|||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"safe", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"strip", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
{"rejectattr", empty_value_fn<value_array>},
|
||||
{"select", empty_value_fn<value_array>},
|
||||
{"selectattr", empty_value_fn<value_array>},
|
||||
{"unique", empty_value_fn<value_array>},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -994,10 +1046,33 @@ const func_builtins & value_none_t::get_builtins() const {
|
|||
const func_builtins & value_undefined_t::get_builtins() const {
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_undefined>();
|
||||
return mk_val<value_string>("null");
|
||||
}},
|
||||
{"capitalize", empty_value_fn<value_string>},
|
||||
{"first", empty_value_fn<value_undefined>},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"join", empty_value_fn<value_string>},
|
||||
{"last", empty_value_fn<value_undefined>},
|
||||
{"length", empty_value_fn<value_int>},
|
||||
{"list", empty_value_fn<value_array>},
|
||||
{"lower", empty_value_fn<value_string>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"max", empty_value_fn<value_undefined>},
|
||||
{"min", empty_value_fn<value_undefined>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
{"rejectattr", empty_value_fn<value_array>},
|
||||
{"replace", empty_value_fn<value_string>},
|
||||
{"reverse", empty_value_fn<value_array>},
|
||||
{"safe", empty_value_fn<value_string>},
|
||||
{"select", empty_value_fn<value_array>},
|
||||
{"selectattr", empty_value_fn<value_array>},
|
||||
{"sort", empty_value_fn<value_array>},
|
||||
{"string", empty_value_fn<value_string>},
|
||||
{"strip", empty_value_fn<value_string>},
|
||||
{"sum", empty_value_fn<value_int>},
|
||||
{"title", empty_value_fn<value_string>},
|
||||
{"truncate", empty_value_fn<value_string>},
|
||||
{"unique", empty_value_fn<value_array>},
|
||||
{"upper", empty_value_fn<value_string>},
|
||||
{"wordcount", empty_value_fn<value_int>},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -1114,6 +1189,8 @@ void global_from_json(context & ctx, const nlohmann::ordered_json & json_obj, bo
|
|||
}
|
||||
}
|
||||
|
||||
// recursively convert value to JSON string
|
||||
// TODO: avoid circular references
|
||||
static void value_to_json_internal(std::ostringstream & oss, const value & val, int curr_lvl, int indent, const std::string_view item_sep, const std::string_view key_sep) {
|
||||
auto indent_str = [indent, curr_lvl]() -> std::string {
|
||||
return (indent > 0) ? std::string(curr_lvl * indent, ' ') : "";
|
||||
|
|
@ -1169,14 +1246,15 @@ static void value_to_json_internal(std::ostringstream & oss, const value & val,
|
|||
}
|
||||
oss << "]";
|
||||
} else if (is_val<value_object>(val)) {
|
||||
const auto & obj = val->val_obj.ordered; // IMPORTANT: need to keep exact order
|
||||
const auto & obj = val->as_ordered_object(); // IMPORTANT: need to keep exact order
|
||||
oss << "{";
|
||||
if (!obj.empty()) {
|
||||
oss << newline();
|
||||
size_t i = 0;
|
||||
for (const auto & pair : obj) {
|
||||
oss << indent_str() << (indent > 0 ? std::string(indent, ' ') : "");
|
||||
oss << "\"" << pair.first << "\"" << key_sep;
|
||||
value_to_json_internal(oss, mk_val<value_string>(pair.first->as_string().str()), curr_lvl + 1, indent, item_sep, key_sep);
|
||||
oss << key_sep;
|
||||
value_to_json_internal(oss, pair.second, curr_lvl + 1, indent, item_sep, key_sep);
|
||||
if (i < obj.size() - 1) {
|
||||
oss << item_sep;
|
||||
|
|
@ -1199,4 +1277,19 @@ std::string value_to_json(const value & val, int indent, const std::string_view
|
|||
return oss.str();
|
||||
}
|
||||
|
||||
// TODO: avoid circular references
|
||||
std::string value_to_string_repr(const value & val) {
|
||||
if (is_val<value_string>(val)) {
|
||||
const std::string val_str = val->as_string().str();
|
||||
|
||||
if (val_str.find('\'') != std::string::npos) {
|
||||
return value_to_json(val);
|
||||
} else {
|
||||
return "'" + val_str + "'";
|
||||
}
|
||||
} else {
|
||||
return val->as_repr();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
#pragma once
|
||||
|
||||
#include "string.h"
|
||||
#include "utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
|
@ -10,6 +12,7 @@
|
|||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace jinja {
|
||||
|
|
@ -93,7 +96,8 @@ void global_from_json(context & ctx, const T_JSON & json_obj, bool mark_input);
|
|||
|
||||
struct func_args; // function argument values
|
||||
|
||||
using func_handler = std::function<value(const func_args &)>;
|
||||
using func_hptr = value(const func_args &);
|
||||
using func_handler = std::function<func_hptr>;
|
||||
using func_builtins = std::map<std::string, func_handler>;
|
||||
|
||||
enum value_compare_op { eq, ge, gt, lt, ne };
|
||||
|
|
@ -103,28 +107,9 @@ struct value_t {
|
|||
int64_t val_int;
|
||||
double val_flt;
|
||||
string val_str;
|
||||
bool val_bool;
|
||||
|
||||
std::vector<value> val_arr;
|
||||
|
||||
struct map {
|
||||
// once set to true, all keys must be numeric
|
||||
// caveat: we only allow either all numeric keys or all non-numeric keys
|
||||
// for now, this only applied to for_statement in case of iterating over object keys/items
|
||||
bool is_key_numeric = false;
|
||||
std::map<std::string, value> unordered;
|
||||
std::vector<std::pair<std::string, value>> ordered;
|
||||
void insert(const std::string & key, const value & val) {
|
||||
if (unordered.find(key) != unordered.end()) {
|
||||
// if key exists, remove from ordered list
|
||||
ordered.erase(std::remove_if(ordered.begin(), ordered.end(),
|
||||
[&](const std::pair<std::string, value> & p) { return p.first == key; }),
|
||||
ordered.end());
|
||||
}
|
||||
unordered[key] = val;
|
||||
ordered.push_back({key, val});
|
||||
}
|
||||
} val_obj;
|
||||
std::vector<std::pair<value, value>> val_obj;
|
||||
|
||||
func_handler val_func;
|
||||
|
||||
|
|
@ -139,6 +124,7 @@ struct value_t {
|
|||
value_t(const value_t &) = default;
|
||||
virtual ~value_t() = default;
|
||||
|
||||
// Note: only for debugging and error reporting purposes
|
||||
virtual std::string type() const { return ""; }
|
||||
|
||||
virtual int64_t as_int() const { throw std::runtime_error(type() + " is not an int value"); }
|
||||
|
|
@ -146,7 +132,7 @@ struct value_t {
|
|||
virtual string as_string() const { throw std::runtime_error(type() + " is not a string value"); }
|
||||
virtual bool as_bool() const { throw std::runtime_error(type() + " is not a bool value"); }
|
||||
virtual const std::vector<value> & as_array() const { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual const std::map<std::string, value> & as_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value invoke(const func_args &) const { throw std::runtime_error(type() + " is not a function value"); }
|
||||
virtual bool is_none() const { return false; }
|
||||
virtual bool is_undefined() const { return false; }
|
||||
|
|
@ -154,28 +140,66 @@ struct value_t {
|
|||
throw std::runtime_error("No builtins available for type " + type());
|
||||
}
|
||||
|
||||
virtual value & at(const std::string & key, value & default_val) {
|
||||
auto it = val_obj.unordered.find(key);
|
||||
if (it == val_obj.unordered.end()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_obj.unordered.at(key);
|
||||
}
|
||||
virtual value & at(const std::string & key) {
|
||||
auto it = val_obj.unordered.find(key);
|
||||
if (it == val_obj.unordered.end()) {
|
||||
throw std::runtime_error("Key '" + key + "' not found in value of type " + type());
|
||||
}
|
||||
return val_obj.unordered.at(key);
|
||||
}
|
||||
virtual value & at(size_t index) {
|
||||
if (index >= val_arr.size()) {
|
||||
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual bool has_key(const value &) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual void insert(const value & /* key */, const value & /* val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const value & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */, value & /* default_val */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(const std::string & /* key */) { throw std::runtime_error(type() + " is not an object value"); }
|
||||
virtual value & at(int64_t /* idx */, value & /* default_val */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
virtual value & at(int64_t /* idx */) { throw std::runtime_error(type() + " is not an array value"); }
|
||||
|
||||
virtual bool is_numeric() const { return false; }
|
||||
virtual bool is_hashable() const { return false; }
|
||||
virtual bool is_immutable() const { return true; }
|
||||
virtual hasher unique_hash() const noexcept = 0;
|
||||
// TODO: C++20 <=> operator
|
||||
// NOTE: We are treating == as equivalent (for normal comparisons) and != as strict nonequal (for strict (is) comparisons)
|
||||
virtual bool operator==(const value_t & other) const { return equivalent(other); }
|
||||
virtual bool operator!=(const value_t & other) const { return nonequal(other); }
|
||||
|
||||
// Note: only for debugging purposes
|
||||
virtual std::string as_repr() const { return as_string().str(); }
|
||||
|
||||
protected:
|
||||
virtual bool equivalent(const value_t &) const = 0;
|
||||
virtual bool nonequal(const value_t & other) const { return !equivalent(other); }
|
||||
};
|
||||
|
||||
//
|
||||
// utils
|
||||
//
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
|
||||
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
|
||||
|
||||
// Note: only used for debugging purposes
|
||||
std::string value_to_string_repr(const value & val);
|
||||
|
||||
struct not_implemented_exception : public std::runtime_error {
|
||||
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
|
||||
};
|
||||
|
||||
struct value_hasher {
|
||||
size_t operator()(const value & val) const noexcept {
|
||||
return val->unique_hash().digest();
|
||||
}
|
||||
};
|
||||
|
||||
struct value_equivalence {
|
||||
bool operator()(const value & lhs, const value & rhs) const {
|
||||
return *lhs == *rhs;
|
||||
}
|
||||
bool operator()(const std::pair<value, value> & lhs, const std::pair<value, value> & rhs) const {
|
||||
return *(lhs.first) == *(rhs.first) && *(lhs.second) == *(rhs.second);
|
||||
}
|
||||
};
|
||||
|
||||
struct value_equality {
|
||||
bool operator()(const value & lhs, const value & rhs) const {
|
||||
return !(*lhs != *rhs);
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
|
|
@ -183,28 +207,77 @@ struct value_t {
|
|||
//
|
||||
|
||||
struct value_int_t : public value_t {
|
||||
value_int_t(int64_t v) { val_int = v; }
|
||||
value_int_t(int64_t v) {
|
||||
val_int = v;
|
||||
val_flt = static_cast<double>(v);
|
||||
if (static_cast<int64_t>(val_flt) != v) {
|
||||
val_flt = v < 0 ? -INFINITY : INFINITY;
|
||||
}
|
||||
}
|
||||
virtual std::string type() const override { return "Integer"; }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual double as_float() const override { return static_cast<double>(val_int); }
|
||||
virtual double as_float() const override { return val_flt; }
|
||||
virtual string as_string() const override { return std::to_string(val_int); }
|
||||
virtual bool as_bool() const override {
|
||||
return val_int != 0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this))
|
||||
.update(&val_int, sizeof(val_int))
|
||||
.update(&val_flt, sizeof(val_flt));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
|
||||
}
|
||||
};
|
||||
using value_int = std::shared_ptr<value_int_t>;
|
||||
|
||||
|
||||
struct value_float_t : public value_t {
|
||||
value_float_t(double v) { val_flt = v; }
|
||||
value val;
|
||||
value_float_t(double v) {
|
||||
val_flt = v;
|
||||
val_int = std::isfinite(v) ? static_cast<int64_t>(v) : 0;
|
||||
val = mk_val<value_int>(val_int);
|
||||
}
|
||||
virtual std::string type() const override { return "Float"; }
|
||||
virtual double as_float() const override { return val_flt; }
|
||||
virtual int64_t as_int() const override { return static_cast<int64_t>(val_flt); }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual string as_string() const override {
|
||||
std::string out = std::to_string(val_flt);
|
||||
out.erase(out.find_last_not_of('0') + 1, std::string::npos); // remove trailing zeros
|
||||
if (out.back() == '.') out.push_back('0'); // leave one zero if no decimals
|
||||
return out;
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return val_flt != 0.0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
if (static_cast<double>(val_int) == val_flt) {
|
||||
return val->unique_hash();
|
||||
} else {
|
||||
return hasher(typeid(*this))
|
||||
.update(&val_int, sizeof(val_int))
|
||||
.update(&val_flt, sizeof(val_flt));
|
||||
}
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_flt == other.val_flt);
|
||||
}
|
||||
};
|
||||
using value_float = std::shared_ptr<value_float_t>;
|
||||
|
||||
|
|
@ -226,19 +299,49 @@ struct value_string_t : public value_t {
|
|||
return val_str.length() > 0;
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
const auto type_hash = typeid(*this).hash_code();
|
||||
auto hash = hasher();
|
||||
hash.update(&type_hash, sizeof(type_hash));
|
||||
val_str.hash_update(hash);
|
||||
return hash;
|
||||
}
|
||||
void mark_input() {
|
||||
val_str.mark_input();
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && val_str.str() == other.val_str.str();
|
||||
}
|
||||
};
|
||||
using value_string = std::shared_ptr<value_string_t>;
|
||||
|
||||
|
||||
struct value_bool_t : public value_t {
|
||||
value_bool_t(bool v) { val_bool = v; }
|
||||
value val;
|
||||
value_bool_t(bool v) {
|
||||
val_int = static_cast<int64_t>(v);
|
||||
val_flt = static_cast<double>(v);
|
||||
val = mk_val<value_int>(val_int);
|
||||
}
|
||||
virtual std::string type() const override { return "Boolean"; }
|
||||
virtual bool as_bool() const override { return val_bool; }
|
||||
virtual string as_string() const override { return std::string(val_bool ? "True" : "False"); }
|
||||
virtual int64_t as_int() const override { return val_int; }
|
||||
virtual bool as_bool() const override { return val_int; }
|
||||
virtual string as_string() const override { return std::string(val_int ? "True" : "False"); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_numeric() const override { return true; }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return val->unique_hash();
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return other.is_numeric() && val_int == other.val_int && val_flt == other.val_flt;
|
||||
}
|
||||
virtual bool nonequal(const value_t & other) const override {
|
||||
return !(typeid(*this) == typeid(other) && val_int == other.val_int);
|
||||
}
|
||||
};
|
||||
using value_bool = std::shared_ptr<value_bool_t>;
|
||||
|
||||
|
|
@ -248,13 +351,34 @@ struct value_array_t : public value_t {
|
|||
value_array_t(value & v) {
|
||||
val_arr = v->val_arr;
|
||||
}
|
||||
value_array_t(std::vector<value> && arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_array_t(const std::vector<value> & arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
void reverse() { std::reverse(val_arr.begin(), val_arr.end()); }
|
||||
void push_back(const value & val) { val_arr.push_back(val); }
|
||||
void push_back(value && val) { val_arr.push_back(std::move(val)); }
|
||||
void reverse() {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
std::reverse(val_arr.begin(), val_arr.end());
|
||||
}
|
||||
void push_back(const value & val) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
val_arr.push_back(val);
|
||||
}
|
||||
void push_back(value && val) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
val_arr.push_back(std::move(val));
|
||||
}
|
||||
value pop_at(int64_t index) {
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
if (index < 0) {
|
||||
index = static_cast<int64_t>(val_arr.size()) + index;
|
||||
}
|
||||
|
|
@ -266,61 +390,228 @@ struct value_array_t : public value_t {
|
|||
return val;
|
||||
}
|
||||
virtual std::string type() const override { return "Array"; }
|
||||
virtual bool is_immutable() const override { return false; }
|
||||
virtual const std::vector<value> & as_array() const override { return val_arr; }
|
||||
virtual string as_string() const override {
|
||||
const bool immutable = is_immutable();
|
||||
std::ostringstream ss;
|
||||
ss << "[";
|
||||
ss << (immutable ? "(" : "[");
|
||||
for (size_t i = 0; i < val_arr.size(); i++) {
|
||||
if (i > 0) ss << ", ";
|
||||
ss << val_arr.at(i)->as_repr();
|
||||
value val = val_arr.at(i);
|
||||
ss << value_to_string_repr(val);
|
||||
}
|
||||
ss << "]";
|
||||
if (immutable && val_arr.size() == 1) {
|
||||
ss << ",";
|
||||
}
|
||||
ss << (immutable ? ")" : "]");
|
||||
return ss.str();
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return !val_arr.empty();
|
||||
}
|
||||
virtual value & at(int64_t index, value & default_val) override {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
return default_val;
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual value & at(int64_t index) override {
|
||||
if (index < 0) {
|
||||
index += val_arr.size();
|
||||
}
|
||||
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
|
||||
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
|
||||
}
|
||||
return val_arr[index];
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override {
|
||||
if (std::all_of(val_arr.begin(), val_arr.end(), [&](auto & val) -> bool {
|
||||
return val->is_immutable() && val->is_hashable();
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
auto hash = hasher(typeid(*this));
|
||||
for (const auto & val : val_arr) {
|
||||
// must use digest to prevent problems from "concatenation" property of hasher
|
||||
// for ex. hash of [ "ab", "c" ] should be different from [ "a", "bc" ]
|
||||
const size_t val_hash = val->unique_hash().digest();
|
||||
hash.update(&val_hash, sizeof(size_t));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_arr.begin(), val_arr.end(), other.val_arr.begin(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_array = std::shared_ptr<value_array_t>;
|
||||
|
||||
|
||||
struct value_tuple_t : public value_array_t {
|
||||
value_tuple_t(value & v) {
|
||||
val_arr = v->val_arr;
|
||||
}
|
||||
value_tuple_t(std::vector<value> && arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_tuple_t(const std::vector<value> & arr) {
|
||||
val_arr = arr;
|
||||
}
|
||||
value_tuple_t(const std::pair<value, value> & pair) {
|
||||
val_arr.push_back(pair.first);
|
||||
val_arr.push_back(pair.second);
|
||||
}
|
||||
virtual std::string type() const override { return "Tuple"; }
|
||||
virtual bool is_immutable() const override { return true; }
|
||||
};
|
||||
using value_tuple = std::shared_ptr<value_tuple_t>;
|
||||
|
||||
|
||||
struct value_object_t : public value_t {
|
||||
std::unordered_map<value, value, value_hasher, value_equivalence> unordered;
|
||||
bool has_builtins = true; // context and loop objects do not have builtins
|
||||
value_object_t() = default;
|
||||
value_object_t(value & v) {
|
||||
val_obj = v->val_obj;
|
||||
for (const auto & pair : val_obj) {
|
||||
unordered[pair.first] = pair.second;
|
||||
}
|
||||
}
|
||||
value_object_t(const std::map<std::string, value> & obj) {
|
||||
value_object_t(const std::map<value, value> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
val_obj.insert(pair.first, pair.second);
|
||||
insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
value_object_t(const std::vector<std::pair<value, value>> & obj) {
|
||||
for (const auto & pair : obj) {
|
||||
insert(pair.first, pair.second);
|
||||
}
|
||||
}
|
||||
void insert(const std::string & key, const value & val) {
|
||||
val_obj.insert(key, val);
|
||||
insert(mk_val<value_string>(key), val);
|
||||
}
|
||||
virtual std::string type() const override { return "Object"; }
|
||||
virtual const std::map<std::string, value> & as_object() const override { return val_obj.unordered; }
|
||||
virtual bool is_immutable() const override { return false; }
|
||||
virtual const std::vector<std::pair<value, value>> & as_ordered_object() const override { return val_obj; }
|
||||
virtual string as_string() const override {
|
||||
std::ostringstream ss;
|
||||
ss << "{";
|
||||
for (size_t i = 0; i < val_obj.size(); i++) {
|
||||
if (i > 0) ss << ", ";
|
||||
auto & [key, val] = val_obj.at(i);
|
||||
ss << value_to_string_repr(key) << ": " << value_to_string_repr(val);
|
||||
}
|
||||
ss << "}";
|
||||
return ss.str();
|
||||
}
|
||||
virtual bool as_bool() const override {
|
||||
return !val_obj.unordered.empty();
|
||||
return !unordered.empty();
|
||||
}
|
||||
virtual bool has_key(const value & key) override {
|
||||
if (!key->is_immutable() || !key->is_hashable()) {
|
||||
throw std::runtime_error("Object key of unhashable type: " + key->type());
|
||||
}
|
||||
return unordered.find(key) != unordered.end();
|
||||
}
|
||||
virtual void insert(const value & key, const value & val) override {
|
||||
bool replaced = false;
|
||||
if (is_immutable()) {
|
||||
throw std::runtime_error("Attempting to modify immutable type");
|
||||
}
|
||||
if (has_key(key)) {
|
||||
// if key exists, replace value in ordered list instead of appending
|
||||
for (auto & pair : val_obj) {
|
||||
if (*(pair.first) == *key) {
|
||||
pair.second = val;
|
||||
replaced = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
unordered[key] = val;
|
||||
if (!replaced) {
|
||||
val_obj.push_back({key, val});
|
||||
}
|
||||
}
|
||||
virtual value & at(const value & key, value & default_val) override {
|
||||
if (!has_key(key)) {
|
||||
return default_val;
|
||||
}
|
||||
return unordered.at(key);
|
||||
}
|
||||
virtual value & at(const value & key) override {
|
||||
if (!has_key(key)) {
|
||||
throw std::runtime_error("Key '" + key->as_string().str() + "' not found in value of type " + type());
|
||||
}
|
||||
return unordered.at(key);
|
||||
}
|
||||
virtual value & at(const std::string & key, value & default_val) override {
|
||||
value key_val = mk_val<value_string>(key);
|
||||
return at(key_val, default_val);
|
||||
}
|
||||
virtual value & at(const std::string & key) override {
|
||||
value key_val = mk_val<value_string>(key);
|
||||
return at(key_val);
|
||||
}
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override {
|
||||
if (std::all_of(val_obj.begin(), val_obj.end(), [&](auto & pair) -> bool {
|
||||
const auto & val = pair.second;
|
||||
return val->is_immutable() && val->is_hashable();
|
||||
})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
auto hash = hasher(typeid(*this));
|
||||
for (const auto & [key, val] : val_obj) {
|
||||
// must use digest to prevent problems from "concatenation" property of hasher
|
||||
// for ex. hash of key="ab", value="c" should be different from key="a", value="bc"
|
||||
const size_t key_hash = key->unique_hash().digest();
|
||||
const size_t val_hash = val->unique_hash().digest();
|
||||
hash.update(&key_hash, sizeof(key_hash));
|
||||
hash.update(&val_hash, sizeof(val_hash));
|
||||
}
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other) && is_hashable() && other.is_hashable() && std::equal(val_obj.begin(), val_obj.end(), other.val_obj.begin(), value_equivalence());
|
||||
}
|
||||
};
|
||||
using value_object = std::shared_ptr<value_object_t>;
|
||||
|
||||
//
|
||||
// null and undefined types
|
||||
// none and undefined types
|
||||
//
|
||||
|
||||
struct value_none_t : public value_t {
|
||||
virtual std::string type() const override { return "None"; }
|
||||
virtual bool is_none() const override { return true; }
|
||||
virtual bool as_bool() const override { return false; }
|
||||
virtual string as_string() const override { return string(type()); }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return typeid(*this) == typeid(other);
|
||||
}
|
||||
};
|
||||
using value_none = std::shared_ptr<value_none_t>;
|
||||
|
||||
|
||||
struct value_undefined_t : public value_t {
|
||||
std::string hint; // for debugging, to indicate where undefined came from
|
||||
value_undefined_t(const std::string & h = "") : hint(h) {}
|
||||
|
|
@ -329,6 +620,13 @@ struct value_undefined_t : public value_t {
|
|||
virtual bool as_bool() const override { return false; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual const func_builtins & get_builtins() const override;
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
return hasher(typeid(*this));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
return is_undefined() == other.is_undefined();
|
||||
}
|
||||
};
|
||||
using value_undefined = std::shared_ptr<value_undefined_t>;
|
||||
|
||||
|
|
@ -409,7 +707,23 @@ struct value_func_t : public value_t {
|
|||
return val_func(new_args);
|
||||
}
|
||||
virtual std::string type() const override { return "Function"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual std::string as_repr() const override { return type() + "<" + name + ">(" + (arg0 ? arg0->as_repr() : "") + ")"; }
|
||||
virtual bool is_hashable() const override { return false; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
// Note: this is unused for now, we don't support function as object keys
|
||||
// use function pointer as unique identifier
|
||||
const auto target = val_func.target<func_hptr>();
|
||||
return hasher(typeid(*this)).update(&target, sizeof(target));
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
// Note: this is unused for now, we don't support function as object keys
|
||||
// compare function pointers
|
||||
// (val_func == other.val_func does not work as std::function::operator== is only used for nullptr check)
|
||||
const auto target_this = this->val_func.target<func_hptr>();
|
||||
const auto target_other = other.val_func.target<func_hptr>();
|
||||
return typeid(*this) == typeid(other) && target_this == target_other;
|
||||
}
|
||||
};
|
||||
using value_func = std::shared_ptr<value_func_t>;
|
||||
|
||||
|
|
@ -420,18 +734,21 @@ struct value_kwarg_t : public value_t {
|
|||
value_kwarg_t(const std::string & k, const value & v) : key(k), val(v) {}
|
||||
virtual std::string type() const override { return "KwArg"; }
|
||||
virtual std::string as_repr() const override { return type(); }
|
||||
virtual bool is_hashable() const override { return true; }
|
||||
virtual hasher unique_hash() const noexcept override {
|
||||
const auto type_hash = typeid(*this).hash_code();
|
||||
auto hash = val->unique_hash();
|
||||
hash.update(&type_hash, sizeof(type_hash))
|
||||
.update(key.data(), key.size());
|
||||
return hash;
|
||||
}
|
||||
protected:
|
||||
virtual bool equivalent(const value_t & other) const override {
|
||||
const value_kwarg_t & other_val = static_cast<const value_kwarg_t &>(other);
|
||||
return typeid(*this) == typeid(other) && key == other_val.key && val == other_val.val;
|
||||
}
|
||||
};
|
||||
using value_kwarg = std::shared_ptr<value_kwarg_t>;
|
||||
|
||||
|
||||
// utils
|
||||
|
||||
const func_builtins & global_builtins();
|
||||
std::string value_to_json(const value & val, int indent = -1, const std::string_view item_sep = ", ", const std::string_view key_sep = ": ");
|
||||
|
||||
struct not_implemented_exception : public std::runtime_error {
|
||||
not_implemented_exception(const std::string & msg) : std::runtime_error("NotImplemented: " + msg) {}
|
||||
};
|
||||
|
||||
|
||||
} // namespace jinja
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
|
|
|
|||
|
|
@ -192,12 +192,12 @@ void common_ngram_cache_draft(
|
|||
break;
|
||||
}
|
||||
|
||||
LOG(" - draft candidate: token=%d\n", drafted_token);
|
||||
LOG_DBG(" - draft candidate: token=%d\n", drafted_token);
|
||||
draft.push_back(drafted_token);
|
||||
}
|
||||
}
|
||||
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) {
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) {
|
||||
std::ofstream file_out(filename, std::ios::binary);
|
||||
for (std::pair<common_ngram, common_ngram_cache_part> item : ngram_cache) {
|
||||
const common_ngram ngram = item.first;
|
||||
|
|
@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil
|
|||
file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename) {
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename) {
|
||||
std::ifstream hashmap_file(filename, std::ios::binary);
|
||||
if (!hashmap_file) {
|
||||
throw std::ifstream::failure("Unable to open file " + filename);
|
||||
|
|
|
|||
|
|
@ -88,12 +88,12 @@ void common_ngram_cache_draft(
|
|||
// Save an ngram cache to a file.
|
||||
// ngram_cache: the ngram cache to save.
|
||||
// filename: the path under which to save the ngram cache.
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename);
|
||||
void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename);
|
||||
|
||||
// Load an ngram cache saved with common_ngram_cache_save.
|
||||
// filename: the path from which to load the ngram cache.
|
||||
// returns: an ngram cache containing the information saved to filename.
|
||||
common_ngram_cache common_ngram_cache_load(std::string & filename);
|
||||
common_ngram_cache common_ngram_cache_load(const std::string & filename);
|
||||
|
||||
// Merge two ngram caches.
|
||||
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
|
||||
|
|
|
|||
|
|
@ -0,0 +1,362 @@
|
|||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "ngram-map.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <sstream>
|
||||
|
||||
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
|
||||
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
|
||||
std::ostringstream oss;
|
||||
oss << '[';
|
||||
for (size_t i = 0; i < length; ++i) {
|
||||
if (i > 0) {
|
||||
oss << ", ";
|
||||
}
|
||||
oss << inp[start + i];
|
||||
}
|
||||
oss << ']';
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
/**
|
||||
* Perform speculative generation using the model's own token history.
|
||||
* Searches for a matching pattern in the token history and returns draft tokens.
|
||||
*
|
||||
* @param state Current state of this implementation
|
||||
* @param tokens Token history to search in
|
||||
* @param sampled Last sampled token
|
||||
* @return Vector of draft tokens, empty if no matching pattern is found
|
||||
*/
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled) {
|
||||
|
||||
// Simple implementation of self-speculative decoding without a draft model.
|
||||
//
|
||||
const size_t cur_len = tokens.size();
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (state.idx_last_check + state.config.check_rate > cur_len) {
|
||||
llama_tokens draft_tokens;
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
|
||||
size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
|
||||
|
||||
// vector for tokens we want to verify.
|
||||
// return empty vector if there is no match.
|
||||
llama_tokens draft_tokens;
|
||||
|
||||
// We need at least n_draft_min + n_draft_max + 1 tokens.
|
||||
if (cur_len <= static_cast<size_t>(n_draft_min + n_draft_max + 1)) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
// pattern search
|
||||
llama_tokens pattern;
|
||||
pattern.reserve(n_draft_min);
|
||||
for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) {
|
||||
pattern.push_back(tokens[j]);
|
||||
}
|
||||
pattern.push_back(sampled); // add the last token to the pattern
|
||||
|
||||
// We do a search in the token history.
|
||||
state.idx_last_check = cur_len;
|
||||
|
||||
size_t match_pos = 0; // we ignore position 0, position 0 == no match
|
||||
// search backwards, but skip the current match (we are currently there)
|
||||
for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < pattern.size(); ++k) {
|
||||
if (tokens[j + k] != pattern[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos == 0) {
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
const size_t copy_max = std::min(
|
||||
n_draft_max,
|
||||
cur_len - (match_pos + n_draft_min)
|
||||
);
|
||||
if (copy_max < n_draft_min) {
|
||||
return draft_tokens;
|
||||
}
|
||||
LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n",
|
||||
__func__, cur_len,
|
||||
match_pos, pattern.size(), copy_max);
|
||||
|
||||
draft_tokens.reserve(copy_max);
|
||||
for (size_t j = 0; j < copy_max; ++j) {
|
||||
draft_tokens.push_back(tokens[match_pos + n_draft_min + j]);
|
||||
}
|
||||
return draft_tokens;
|
||||
}
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of counted values of a ngram map value.
|
||||
#define COMMON_NGRAM_MAX_VALUE_COUNT 16380
|
||||
|
||||
void common_ngram_map_draft(common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft) {
|
||||
// reset last key and value.
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = 0;
|
||||
map.last_draft_value_idx = 0;
|
||||
|
||||
const size_t cur_len = inp.size();
|
||||
const uint16_t n = map.size_key;
|
||||
const uint16_t m = map.size_value;
|
||||
if (cur_len < static_cast<size_t>(2 * n + m)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only check every check_rate tokens to save compute
|
||||
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
|
||||
if (map.idx_last_check + map.check_rate > cur_len) {
|
||||
return;
|
||||
}
|
||||
map.idx_last_check = cur_len;
|
||||
|
||||
// search pattern, the key n-gram
|
||||
std::vector<llama_token> key_tokens;
|
||||
key_tokens.reserve(n);
|
||||
for (size_t j = cur_len - n + 1; j < cur_len; ++j) {
|
||||
key_tokens.push_back(inp[j]);
|
||||
}
|
||||
key_tokens.push_back(sampled);
|
||||
|
||||
// search for the key in the map
|
||||
size_t match_pos = 0;
|
||||
for (size_t j = cur_len - n - m - 1; j > 0; --j) {
|
||||
bool match = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[j + k] != key_tokens[k]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
match_pos = j;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match_pos > 0) {
|
||||
LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
|
||||
cur_len, n, m, key_tokens.size(), sampled, match_pos);
|
||||
}
|
||||
|
||||
if (match_pos == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We have a match, now we look for the statistics of the key.
|
||||
size_t key_offset = map.keys.size(); // offset in the map
|
||||
// We iterate through the std::vector<common_ngram_map_key> map->keys.
|
||||
for (size_t i = 0; i < map.keys.size(); ++i) {
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < n; ++j) {
|
||||
if (inp[map.keys[i].key_idx + j] != key_tokens[j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
key_offset = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (key_offset == map.keys.size()) {
|
||||
// We create a new key-entry, it will get offset key_offset.
|
||||
common_ngram_map_key new_key;
|
||||
new_key.key_idx = match_pos;
|
||||
new_key.stat_idx = 0;
|
||||
new_key.key_num = 0;
|
||||
for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) {
|
||||
new_key.values[i].value_num = 0;
|
||||
new_key.values[i].n_accepted = m;
|
||||
}
|
||||
map.keys.push_back(new_key);
|
||||
}
|
||||
|
||||
// our key n-gram:
|
||||
common_ngram_map_key & curr_key = map.keys[key_offset];
|
||||
|
||||
// update number of key hits
|
||||
curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
|
||||
if (map.key_only) {
|
||||
// simple mode:
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
// We work with value values[0] only.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = false;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = 0; // value 0 is used for simple mode
|
||||
return;
|
||||
}
|
||||
|
||||
if (curr_key.key_num < map.min_hits) {
|
||||
// not enough hits to consider this a good draft
|
||||
LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__,
|
||||
key_offset, curr_key.key_num, map.min_hits);
|
||||
return;
|
||||
}
|
||||
|
||||
// complex mode: examine the different m-grams after this key n-gram.
|
||||
//
|
||||
|
||||
// determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram.
|
||||
for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) {
|
||||
// begins the key n-gram at index i?
|
||||
bool match_key = true;
|
||||
for (size_t k = 0; k < n; ++k) {
|
||||
if (inp[i + k] != key_tokens[k]) {
|
||||
match_key = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do we haven a existing value m-gram or a new one after the key at index i?
|
||||
size_t idx_begin_value_key = i + n;
|
||||
int idx_value = -1;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
size_t idx_begin_value_v = curr_key.values[v].value_idx;
|
||||
if (idx_begin_value_v == 0) {
|
||||
// We found an empty value slot => we found a new value m-gram after the key n-gram.
|
||||
curr_key.values[v].value_idx = idx_begin_value_key;
|
||||
curr_key.values[v].value_num = 0;
|
||||
curr_key.values[v].n_accepted = m;
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
bool match = true;
|
||||
for (size_t j = 0; j < m; ++j) {
|
||||
if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
// We found an existing value m-gram after the key n-gram.
|
||||
idx_value = v;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (idx_value >= 0) {
|
||||
// We found a value m-gram of the key n-gram.
|
||||
curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1,
|
||||
(int) COMMON_NGRAM_MAX_VALUE_COUNT);
|
||||
}
|
||||
}
|
||||
// the statistics are updated up to match_pos.
|
||||
curr_key.stat_idx = match_pos;
|
||||
|
||||
// Do we have a value we could use for the draft?
|
||||
uint16_t max_occur = 0;
|
||||
int slot_max = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
if (curr_occur > max_occur) {
|
||||
max_occur = curr_occur;
|
||||
slot_max = v;
|
||||
}
|
||||
}
|
||||
// What is sum of the other occurences?
|
||||
uint32_t sum_occur = 0;
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (v == slot_max) {
|
||||
continue;
|
||||
}
|
||||
uint16_t curr_occur = curr_key.values[v].value_num;
|
||||
sum_occur += curr_occur;
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__,
|
||||
key_offset,
|
||||
max_occur, sum_occur, slot_max,
|
||||
curr_key.values[0].value_idx, curr_key.values[0].value_num,
|
||||
curr_key.values[1].value_idx, curr_key.values[1].value_num,
|
||||
curr_key.values[2].value_idx, curr_key.values[2].value_num,
|
||||
curr_key.values[3].value_idx, curr_key.values[3].value_num
|
||||
);
|
||||
// Print the tokens of the four values (if idx != 0), use LOG_INF
|
||||
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
|
||||
if (curr_key.values[v].value_idx != 0) {
|
||||
LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (sum_occur > 0 && max_occur < 3 * sum_occur) {
|
||||
// The most frequent value is not much more frequent than the other values.
|
||||
// We do not use the draft.
|
||||
return;
|
||||
}
|
||||
|
||||
// We use the most frequent value values[slot_max] for the draft.
|
||||
// Fill in the draft with the m tokens following the key.
|
||||
int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted);
|
||||
|
||||
for (int i = 0; i < n_draft_tokens; ++i) {
|
||||
draft.push_back(inp[match_pos + n + i]);
|
||||
}
|
||||
|
||||
LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__,
|
||||
key_offset, slot_max,
|
||||
curr_key.key_num, draft.size());
|
||||
|
||||
map.last_draft_created = true;
|
||||
map.last_draft_key_idx = key_offset;
|
||||
map.last_draft_value_idx = slot_max; // value used for draft generation.
|
||||
}
|
||||
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
|
||||
if (!map.last_draft_created) {
|
||||
return;
|
||||
}
|
||||
|
||||
// find the key and its chosen value.
|
||||
const size_t key_idx = map.last_draft_key_idx;
|
||||
const size_t val_idx = map.last_draft_value_idx;
|
||||
|
||||
// find key corresponding to key_idx.
|
||||
common_ngram_map_key & curr_key = map.keys[key_idx];
|
||||
// find value corresponding to val_idx.
|
||||
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
|
||||
|
||||
// update the value statistics
|
||||
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
|
||||
n_accepted, curr_value.n_accepted);
|
||||
curr_value.n_accepted = n_accepted;
|
||||
}
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
#pragma once
|
||||
//
|
||||
// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams
|
||||
//
|
||||
// These structures are used to do a lookup of n-grams followed by m-grams in token history.
|
||||
//
|
||||
// There are two algorithms implemented:
|
||||
// 1. ngram_simple: lookup of n-grams followed by m-grams in token history.
|
||||
// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
|
||||
// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
|
||||
//
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
// n-gram simple
|
||||
//
|
||||
|
||||
// config of n-gram simple.
|
||||
struct common_ngram_simple_config {
|
||||
uint16_t size_ngram; // size of n-grams to lookup in self-mode
|
||||
uint16_t size_mgram; // size of m-grams to draft in self-mode
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
};
|
||||
|
||||
// current state (and config) of n-gram simple.
|
||||
struct common_ngram_simple_state {
|
||||
common_ngram_simple_config config;
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history (mutable)
|
||||
|
||||
common_ngram_simple_state(const common_ngram_simple_config & config)
|
||||
: config(config) {}
|
||||
};
|
||||
|
||||
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// state: the ngram simple state to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
llama_tokens common_ngram_simple_draft(
|
||||
common_ngram_simple_state & state,
|
||||
const llama_tokens & tokens, llama_token sampled);
|
||||
|
||||
|
||||
// n-gram map
|
||||
//
|
||||
|
||||
// maximum number of m-gram values stored for each key n-gram.
|
||||
#define COMMON_NGRAM_MAX_VALUES 4
|
||||
|
||||
// statistics of a m-gram after a known n-gram
|
||||
struct common_ngram_map_value {
|
||||
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
|
||||
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
|
||||
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
|
||||
};
|
||||
|
||||
// statistics of a n-gram
|
||||
struct common_ngram_map_key {
|
||||
size_t key_idx; // index of key n-gram in token-history
|
||||
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
|
||||
|
||||
uint16_t key_num; // number of occurences of this key n-gram in token-history
|
||||
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
|
||||
};
|
||||
|
||||
// map from n-grams to following m-grams in token-history
|
||||
struct common_ngram_map {
|
||||
uint16_t size_key; // size of key n-grams
|
||||
uint16_t size_value; // size of value m-grams
|
||||
|
||||
bool key_only; // true if only key n-grams are used, no values.
|
||||
|
||||
// first draft: vector only, no map.
|
||||
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
|
||||
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
|
||||
uint16_t min_hits; // minimum number of key hits to consider a draft
|
||||
|
||||
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
|
||||
uint16_t check_rate, uint16_t min_hits)
|
||||
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
|
||||
check_rate(check_rate), min_hits(min_hits) {}
|
||||
|
||||
bool last_draft_created = false; // true if a draft was created at last call.
|
||||
size_t last_draft_key_idx = 0; // index of last key used for draft generation.
|
||||
uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
|
||||
|
||||
size_t idx_last_check = 0; // index of last check in context history
|
||||
};
|
||||
|
||||
|
||||
// Searches for the n-gram in the history and checks whether a draft sequence should be generated.
|
||||
// map: the ngram map to search in.
|
||||
// inp: the tokens generated so far.
|
||||
// sampled: the token that was just sampled.
|
||||
// draft: vector to store the draft tokens, initially empty.
|
||||
void common_ngram_map_draft(
|
||||
common_ngram_map & map,
|
||||
const llama_tokens & inp, llama_token sampled,
|
||||
llama_tokens & draft);
|
||||
|
||||
// Update the statistics of a value after a draft was processed.
|
||||
void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted);
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
#include "ngram-mod.h"
|
||||
|
||||
//
|
||||
// common_ngram_mod
|
||||
//
|
||||
|
||||
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
|
||||
entries.resize(size);
|
||||
|
||||
reset();
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::idx(const entry_t * tokens) const {
|
||||
size_t res = 0;
|
||||
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
res = res*6364136223846793005ULL + tokens[i];
|
||||
}
|
||||
|
||||
res = res % entries.size();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void common_ngram_mod::add(const entry_t * tokens) {
|
||||
const size_t i = idx(tokens);
|
||||
|
||||
if (entries[i] == EMPTY) {
|
||||
used++;
|
||||
}
|
||||
|
||||
entries[i] = tokens[n];
|
||||
}
|
||||
|
||||
common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
|
||||
const size_t i = idx(tokens);
|
||||
|
||||
return entries[i];
|
||||
}
|
||||
|
||||
void common_ngram_mod::reset() {
|
||||
std::fill(entries.begin(), entries.end(), EMPTY);
|
||||
used = 0;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::get_n() const {
|
||||
return n;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::get_used() const {
|
||||
return used;
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::size() const {
|
||||
return entries.size();
|
||||
}
|
||||
|
||||
size_t common_ngram_mod::size_bytes() const {
|
||||
return entries.size() * sizeof(entries[0]);
|
||||
}
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <cstddef>
|
||||
|
||||
//
|
||||
// common_ngram_mod
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19164
|
||||
//
|
||||
|
||||
// basic n-gram hasher
|
||||
struct common_ngram_mod {
|
||||
using entry_t = int32_t;
|
||||
|
||||
static constexpr entry_t EMPTY = -1;
|
||||
|
||||
common_ngram_mod(uint16_t n, size_t size);
|
||||
|
||||
size_t idx(const entry_t * tokens) const;
|
||||
void add(const entry_t * tokens);
|
||||
entry_t get(const entry_t * tokens) const; // return -1 if not found
|
||||
|
||||
void reset();
|
||||
|
||||
size_t get_n() const;
|
||||
size_t get_used() const;
|
||||
|
||||
size_t size() const;
|
||||
size_t size_bytes() const;
|
||||
|
||||
private:
|
||||
size_t n; // ngram size to hash
|
||||
|
||||
size_t used;
|
||||
|
||||
std::vector<entry_t> entries;
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -5,31 +5,33 @@
|
|||
|
||||
struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16; // max drafted tokens
|
||||
int n_reuse = 256;
|
||||
// comma separated list of all types
|
||||
std::string common_speculative_type_name_str();
|
||||
|
||||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
// convert string to type
|
||||
enum common_speculative_type common_speculative_type_from_name(const std::string & name);
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
);
|
||||
// convert type to string
|
||||
std::string common_speculative_type_to_str(enum common_speculative_type type);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
common_speculative * common_speculative_init(
|
||||
common_params_speculative & params,
|
||||
llama_context * ctx_tgt);
|
||||
|
||||
bool common_speculative_are_compatible(
|
||||
const struct llama_context * ctx_tgt,
|
||||
const struct llama_context * ctx_dft);
|
||||
void common_speculative_free(common_speculative * spec);
|
||||
|
||||
void common_speculative_add_replacement_tgt_dft(
|
||||
struct common_speculative * spec,
|
||||
const char *source, const char *dest);
|
||||
// optionally call once at the beginning of a new generation
|
||||
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
llama_tokens common_speculative_draft(
|
||||
common_speculative * spec,
|
||||
const common_params_speculative & params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
// informs the speculative decoder that n_accepted tokens were accepted by the target model
|
||||
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
|
||||
|
||||
// print statistics about the speculative decoding
|
||||
void common_speculative_print_stats(const common_speculative * spec);
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -170,6 +170,7 @@ pre_computed_hashes = [
|
|||
{"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
|
||||
# jina-v2-de variants
|
||||
{"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
|
||||
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"},
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [Linux](#Linux)
|
||||
- [Known Issue](#known-issues)
|
||||
- [TODO](#todo)
|
||||
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ The following releases are verified and recommended:
|
|||
|
||||
|Commit ID|Tag|Release|Verified Platform| Update date|
|
||||
|-|-|-|-|-|
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|
||||
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|
||||
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
|
||||
|
||||
|
||||
## News
|
||||
|
|
@ -51,7 +51,7 @@ The following releases are verified and recommended:
|
|||
|-|-|-|-|
|
||||
|PVC 1550|39|73|+87%|
|
||||
|Flex 170|39|50|+28%|
|
||||
|Arc770|42|55|+30%|
|
||||
|Arc A770|42|55|+30%|
|
||||
|MTL|13|16|+23%|
|
||||
|ARL-H|14|17|+21%|
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ The following releases are verified and recommended:
|
|||
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
|
||||
|
||||
- 2024.5
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
|
||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770.
|
||||
- Arch Linux is verified successfully.
|
||||
|
||||
- 2024.4
|
||||
|
|
@ -111,7 +111,8 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|
|||
|-------------------------------|---------|---------------------------------------|
|
||||
| Intel Data Center Max Series | Support | Max 1550, 1100 |
|
||||
| Intel Data Center Flex Series | Support | Flex 170 |
|
||||
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 |
|
||||
| Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 |
|
||||
| Intel Arc B-Series | Support | Arc B580 |
|
||||
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
|
||||
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
{
|
||||
"version": 4,
|
||||
"version": 5,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 28,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "arm64-android-snapdragon",
|
||||
|
|
@ -16,7 +21,9 @@
|
|||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
|
|
@ -31,7 +38,15 @@
|
|||
"name": "arm64-windows-snapdragon",
|
||||
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||
"cacheVariables": {
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
# Snapdragon-based Android devices
|
||||
# Snapdragon-based devices
|
||||
|
||||
## How to Build
|
||||
## Setup
|
||||
|
||||
### Android
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
|
|
@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i
|
|||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||
Note: The rest of the **Android** build process assumes that you're running inside the toolchain container.
|
||||
|
||||
### Windows On Snapdragon
|
||||
|
||||
Native Windows 11 arm64 builds has the following tools dependencies:
|
||||
- MS Visual Studio 2026 (Community Edition or Pro)
|
||||
- MSVC arm64 standard and runtime libraries
|
||||
- UCRT and Driver Kit
|
||||
- LLVM core libraries and Clang compiler (winget)
|
||||
- CMake, Git, Python (winget)
|
||||
- Hexagon SDK Community Edition 6.4 or later (see windows.md)
|
||||
- OpenCL SDK 2.3 or later (see windows.md)
|
||||
|
||||
Note: The rest of the **Windows** build process assumes that you're running natively in Powershell.
|
||||
Adapt below build commands accordingly.
|
||||
|
||||
## How to Build
|
||||
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
|
|
@ -49,24 +68,26 @@ Preset CMake variables:
|
|||
To generate an installable "package" simply use cmake --install:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp
|
||||
-- Install configuration: "Release"
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so
|
||||
...
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli
|
||||
...
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
### Android
|
||||
|
||||
For this step, your device needs to be configured for on-device development.
|
||||
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||
|
||||
|
|
@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
|
|||
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/
|
||||
pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||
```
|
||||
|
||||
|
|
@ -92,6 +113,11 @@ At this point, you should also install some models:
|
|||
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||
```
|
||||
|
||||
### Windows
|
||||
|
||||
All artifacts are already installed in the `pkg-snapdragon` folder.
|
||||
To run, adapt below instructions to use Powershell scrits in `scripts/snapdragon/windows`.
|
||||
|
||||
## How to Run
|
||||
|
||||
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
## Overview
|
||||
|
||||
The document covers procedures for installing the latest GPU and NPU drivers, and OpenCL and Hexagon SDKs.
|
||||
|
||||
|
||||
In order to use Hexagon NPU on Snapdragon Windows devices the underlying HTP Ops libraries (e.g libggml-htp-v73.so)
|
||||
must be included in the .cat file digitally signed with a trusted certificate.
|
||||
|
||||
This document covers details on how to generate personal certificate files (.pfx) and how to configure the system
|
||||
to allow for test signatures (aka test-signing).
|
||||
|
||||
## Install the latest Adreno OpenCL SDK
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\OpenCL_SDK\2.3.2
|
||||
```
|
||||
|
||||
## Install the latest Hexagon SDK Community Edition
|
||||
|
||||
Either use the trimmed down version (optimized for CI) from
|
||||
|
||||
https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz
|
||||
|
||||
Or download the complete official version from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2
|
||||
|
||||
Unzip/untar the archive into
|
||||
```
|
||||
c:\Qualcomm\Hexagon_SDK\6.4.0.2
|
||||
```
|
||||
|
||||
## Install the latest Adreno GPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver
|
||||
|
||||
After the automated installation and reboot please make sure that the GPU device shows up in the `Device Manager` (under 'Display Adapters`)
|
||||
|
||||
## Install the latest Qualcomm NPU driver
|
||||
|
||||
Download the driver from
|
||||
|
||||
https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND
|
||||
|
||||
After the automated installation and reboot please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`).
|
||||
|
||||
If the device is not available you can try installing all components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually.
|
||||
The components are extracted into
|
||||
```
|
||||
c:\QCDrivers\qcnspmcdm...
|
||||
```
|
||||
|
||||
## Enable NPU driver test signatures
|
||||
|
||||
Please note that the following steps are required only for the Hexagon NPU.
|
||||
Adreno GPU backend does not require test signatures.
|
||||
|
||||
### Enable testsigning
|
||||
|
||||
Use `bcdedit` to enable test-signing
|
||||
```
|
||||
> bcdedit /set TESTSIGNING ON
|
||||
```
|
||||
(Secure Boot may need to be disabled for this to work)
|
||||
|
||||
Make sure test-signing is enabled after reboot
|
||||
```
|
||||
> bcdedit /enum
|
||||
...
|
||||
testsigning Yes
|
||||
...
|
||||
```
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option
|
||||
|
||||
### Create personal certificate
|
||||
|
||||
The tools required for this procedure are available as part of Windows SDK and Windows Driver Kit which should be
|
||||
installed as part of the MS Visual Studio.
|
||||
They are typically located at
|
||||
```
|
||||
c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0
|
||||
```
|
||||
(replace 10.0.26100.0 with correct version).
|
||||
|
||||
To create personal self-signed certificate run the following commands (either from cmd or power-shell):
|
||||
```
|
||||
> cd c:\Users\MyUser
|
||||
> mkdir Certs
|
||||
> cd Certs
|
||||
> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer
|
||||
> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx
|
||||
```
|
||||
(replace `MyUser` with your username).
|
||||
|
||||
Add this certificate to `Trusted Root Certification Authorities` and `Trusted Publishers` stores.
|
||||
This can be done using `certlm` Certificate Manager tool.
|
||||
Right click on the certificate store, select `All Tasks -> Import` and follow the prompts to import the certificate from the
|
||||
PFX file you created above.
|
||||
|
||||
For additional details see Microsoft guide at
|
||||
|
||||
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing
|
||||
|
||||
Make sure to save the PFX file, you will need it for the build procedures.
|
||||
Please note that the same certificate can be used for signing any number of builds.
|
||||
|
||||
## Build Hexagon backend with signed HTP ops libraries
|
||||
|
||||
The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms.
|
||||
However, additional settings are required for generating and signing HTP Ops libraries.
|
||||
```
|
||||
> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2"
|
||||
> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2"
|
||||
> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04"
|
||||
> $env:HEXAGON_HTP_CERT="c:\Users\MyUsers\Certs\ggml-htp-v1.pfx"
|
||||
> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64"
|
||||
|
||||
> cmake --preset arm64-windows-snapdragon -B build-wos
|
||||
...
|
||||
> cmake --install build-wos --prefix pkg-snapdragon
|
||||
```
|
||||
|
||||
Once the build is complete HTP ops libraries will be installed like this
|
||||
```
|
||||
> dir pkg-snapdragon/lib
|
||||
...
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so
|
||||
-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so
|
||||
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so
|
||||
-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat
|
||||
```
|
||||
|
||||
The .cat file, the signature and proper certicate installation can be verified with
|
||||
|
||||
```
|
||||
> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
Verifying: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
|
||||
Signature Index: 0 (Primary Signature)
|
||||
Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF
|
||||
|
||||
Signing Certificate Chain:
|
||||
Issued to: GGML.HTP.v1
|
||||
...
|
||||
Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat
|
||||
...
|
||||
```
|
||||
|
|
@ -144,7 +144,7 @@ We also have a [guide](./backend/CUDA-FEDORA.md) for setting up CUDA toolkit in
|
|||
- ***Necessary*** for users of [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/); such as: [Silverblue](https://fedoraproject.org/atomic-desktops/silverblue/) and [Kinoite](https://fedoraproject.org/atomic-desktops/kinoite/).
|
||||
- (there are no supported CUDA packages for these systems)
|
||||
- ***Necessary*** for users that have a host that is not a: [Supported Nvidia CUDA Release Platform](https://developer.nvidia.com/cuda-downloads).
|
||||
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your your host operating system)
|
||||
- (for example, you may have [Fedora 42 Beta](https://fedoramagazine.org/announcing-fedora-linux-42-beta/) as your host operating system)
|
||||
- ***Convenient*** For those running [Fedora Workstation](https://fedoraproject.org/workstation/) or [Fedora KDE Plasma Desktop](https://fedoraproject.org/spins/kde), and want to keep their host system clean.
|
||||
- *Optionally* toolbox packages are available: [Arch Linux](https://archlinux.org/), [Red Hat Enterprise Linux >= 8.5](https://www.redhat.com/en/technologies/linux-platforms/enterprise-linux), or [Ubuntu](https://ubuntu.com/download)
|
||||
|
||||
|
|
@ -248,6 +248,14 @@ You may set the [cuda environmental variables](https://docs.nvidia.com/cuda/cuda
|
|||
CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.gguf
|
||||
```
|
||||
|
||||
#### CUDA_SCALE_LAUNCH_QUEUES
|
||||
|
||||
The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU.
|
||||
|
||||
**Default behavior:** llama.cpp automatically sets `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
|
||||
|
||||
See PR [#19042](https://github.com/ggml-org/llama.cpp/pull/19042) for performance benchmarks and technical details.
|
||||
|
||||
### Unified Memory
|
||||
|
||||
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted. In Windows this setting is available in the NVIDIA control panel as `System Memory Fallback`.
|
||||
|
|
@ -487,6 +495,37 @@ Finally, after finishing your build, you should be able to do something like thi
|
|||
# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
|
||||
```
|
||||
|
||||
### For Mac users:
|
||||
|
||||
Generally, follow LunarG's [Getting Started with the MacOS Vulkan SDK](https://vulkan.lunarg.com/doc/sdk/latest/mac/getting_started.html) guide for installation and setup of the Vulkan SDK. There are two options of Vulkan drivers on macOS, both of which implement translation layers to map Vulkan to Metal. They can be hot-swapped by setting the `VK_ICD_FILENAMES` environment variable to point to the respective ICD JSON file.
|
||||
|
||||
Check the box for "KosmicKrisp" during the LunarG Vulkan SDK installation.
|
||||
|
||||
Set environment variable for the LunarG Vulkan SDK after installation (and optionally add to your shell profile for persistence):
|
||||
```bash
|
||||
source /path/to/vulkan-sdk/setup-env.sh
|
||||
```
|
||||
|
||||
#### Using MoltenVK
|
||||
|
||||
MoltenVK is the default Vulkan driver installed with the LunarG Vulkan SDK on macOS, so you can use the above environment variable settings as is.
|
||||
|
||||
#### Using KosmicKrisp
|
||||
|
||||
Override the environment variable for KosmicKrisp:
|
||||
```bash
|
||||
export VK_ICD_FILENAMES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
|
||||
export VK_DRIVER_FILES=$VULKAN_SDK/share/vulkan/icd.d/libkosmickrisp_icd.json
|
||||
```
|
||||
|
||||
#### Build
|
||||
|
||||
This is the only step different from [above](#common-steps) instructions.
|
||||
```bash
|
||||
cmake -B build -DGGML_VULKAN=1 -DGGML_METAL=OFF
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
## CANN
|
||||
This provides NPU acceleration using the AI cores of your Ascend NPU. And [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical APIs to help you to quickly build AI applications and service based on Ascend NPU.
|
||||
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ Legend:
|
|||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||
|
|
@ -114,7 +114,7 @@ Legend:
|
|||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
|
|
|
|||
1368
docs/ops/SYCL.csv
1368
docs/ops/SYCL.csv
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,120 @@
|
|||
# Speculative Decoding
|
||||
|
||||
llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model.
|
||||
|
||||
[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct.
|
||||
|
||||
## Implementations
|
||||
|
||||
The `llama-server` application supports several implementations of speculative decoding:
|
||||
|
||||
### Draft Model (`draft`)
|
||||
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
A draft model is the most used approach in speculative decoding.
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy.
|
||||
|
||||
See:
|
||||
|
||||
- #5479, #6828, #6848
|
||||
|
||||
### n-gram Map (`ngram-simple`, `ngram-map-*`)
|
||||
|
||||
These implementations search the token history for patterns and use matching sequences as draft candidates.
|
||||
They require no additional model but rely on patterns that have already appeared in the generated text.
|
||||
An example to use this approach can be the rewriting of source code by a LLM.
|
||||
|
||||
#### n-gram Map (`ngram-simple`)
|
||||
|
||||
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
|
||||
|
||||
#### n-gram Map Key (`ngram-map-k`)
|
||||
|
||||
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
#### n-gram Map Key-4-Values (`ngram-map-k4v`)
|
||||
|
||||
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
|
||||
|
||||
The number of accepted tokens is stored for each used n-gram.
|
||||
|
||||
**Example:** Server options to be used if there are a lot of longer repetitions.
|
||||
```bash
|
||||
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2
|
||||
```
|
||||
|
||||
|
||||
## Command-Line Options
|
||||
|
||||
If a draft model is combined with a draftless decoding the draftless decoding has higher precedence.
|
||||
|
||||
```
|
||||
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]
|
||||
type of speculative decoding to use when no draft model is provided
|
||||
(default: none)
|
||||
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
|
||||
of lookup n-gram (default: 12)
|
||||
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
|
||||
of draft m-gram (default: 48)
|
||||
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
|
||||
(default: 1)
|
||||
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
|
||||
```
|
||||
|
||||
### `--spec-type TYPE`
|
||||
|
||||
Specifies a type of speculative decoding without draft model.
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
|
||||
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
|
||||
|
||||
**Example:** Server-instance used to refactor source code.
|
||||
```bash
|
||||
./llama-server [...] --spec-type ngram-simple
|
||||
```
|
||||
|
||||
### `--spec-ngram-size-n N`
|
||||
|
||||
Sets the size N of the lookup n-gram for n-gram map based speculative decoding.
|
||||
The n-gram size N determines how many tokens in a row to look back when searching for matching patterns.
|
||||
|
||||
### `--spec-ngram-size-m M`
|
||||
|
||||
Sets the size M of the draft m-gram for n-gram map based speculative decoding.
|
||||
The m-gram size determines how many tokens to draft when a match is found.
|
||||
Larger values can provide more speedup but may reduce acceptance rate.
|
||||
|
||||
### `--spec-ngram-check-rate R`
|
||||
|
||||
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
|
||||
|
||||
### `--spec-ngram-min-hits H`
|
||||
|
||||
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
|
||||
|
||||
## Statistics
|
||||
Each speculative decoding implementation prints statistics.
|
||||
|
||||
```
|
||||
draft acceptance rate = 0.57576 ( 171 accepted / 297 generated)
|
||||
statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73
|
||||
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
|
||||
```
|
||||
|
||||
- `#calls`: number of calls of this implementations
|
||||
- `#gen drafts`: number of drafts generated by this implementation
|
||||
- `#acc drafts`: number of drafts accepted (partially) by the main model
|
||||
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
|
||||
- `#acc tokens`: number of tokens accepted by the main model
|
||||
|
||||
|
|
@ -50,6 +50,12 @@ int main(int argc, char ** argv) {
|
|||
const int N = 5; // n-gram size
|
||||
const int G = 15; // max verification n-grams
|
||||
|
||||
// lookahead requires W + G + 1 sequences for parallel Jacobi decoding
|
||||
params.n_parallel = W + G + 1;
|
||||
|
||||
// unified KV cache is required for coupled sequences in batch splitting
|
||||
params.kv_unified = true;
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
|
@ -115,7 +121,7 @@ int main(int argc, char ** argv) {
|
|||
// seq_id == 0 : the current input token
|
||||
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
|
||||
// seq_id [W + 1, W + G] : verification n-grams
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ int main(int argc, char ** argv){
|
|||
|
||||
common_ngram_cache ngram_cache;
|
||||
common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
|
||||
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str());
|
||||
|
||||
common_ngram_cache_save(ngram_cache, params.lookup_cache_static);
|
||||
common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,18 +46,18 @@ int main(int argc, char ** argv){
|
|||
{
|
||||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,18 +51,18 @@ int main(int argc, char ** argv){
|
|||
const int64_t t_start_draft_us = ggml_time_us();
|
||||
common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
|
||||
|
||||
if (!params.lookup_cache_static.empty()) {
|
||||
if (!params.speculative.lookup_cache_static.empty()) {
|
||||
try {
|
||||
ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
|
||||
ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static);
|
||||
} catch (std::ifstream::failure const &) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
|
||||
LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str());
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!params.lookup_cache_dynamic.empty()) {
|
||||
if (!params.speculative.lookup_cache_dynamic.empty()) {
|
||||
try {
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
|
||||
ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic);
|
||||
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
|
||||
}
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ int main(int argc, char ** argv){
|
|||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
|
|
@ -210,7 +210,7 @@ int main(int argc, char ** argv){
|
|||
|
||||
// Update dynamic ngram cache with context ngram cache and save it to disk:
|
||||
common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
|
||||
common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic);
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ set -e
|
|||
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -13,6 +14,10 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
|
||||
|
|
|
|||
|
|
@ -5,11 +5,16 @@ set -e
|
|||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$MODEL_TESTING_PROMPT"]; then
|
||||
if [ -z "$MODEL_TESTING_PROMPT" ]; then
|
||||
MODEL_TESTING_PROMPT="Hello, my name is"
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
|
|
@ -21,6 +26,6 @@ fi
|
|||
echo $CONVERTED_MODEL
|
||||
echo $MODEL_TESTING_PROMPT
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
|
||||
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
|
||||
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ done
|
|||
|
||||
# First try command line argument, then environment variable
|
||||
CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}"
|
||||
BUILD_DIR="${BUILD_DIR:-"../../build"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -50,5 +51,5 @@ fi
|
|||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-debug -j8
|
||||
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
|
||||
cmake --build ${BUILD_DIR} --target llama-debug -j8
|
||||
${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
set -e
|
||||
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -25,9 +26,13 @@ mkdir -p ppl
|
|||
OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld"
|
||||
echo "Model: $CONVERTED_MODEL"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \
|
||||
-f ppl/wikitext-2-raw/wiki.test.raw \
|
||||
--kl-divergence-base $OUTPUTFILE
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
|
|
@ -20,8 +21,12 @@ if [ ! -d "ppl/wikitext-2-raw" ]; then
|
|||
popd
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,8 @@
|
|||
set -e
|
||||
|
||||
QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}"
|
||||
LOGITS_FILE="${1:-"$LOGITS_FILE"}"
|
||||
LOGITS_FILE="${2:-"$LOGITS_FILE"}"
|
||||
BUILD_DIR="${3:-"$BUILD_DIR"}"
|
||||
|
||||
if [ -z "$QUANTIZED_MODEL" ]; then
|
||||
echo "Error: Model path must be provided either as:" >&2
|
||||
|
|
@ -18,11 +19,15 @@ if [ ! -f ${LOGITS_FILE} ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo "Model: $QUANTIZED_MODEL"
|
||||
echo "Data file: $LOGITS_FILE"
|
||||
|
||||
cmake --build ../../build --target llama-perplexity -j8
|
||||
cmake --build $BUILD_DIR --target llama-perplexity -j8
|
||||
|
||||
../.././build/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \
|
||||
--kl-divergence-base $LOGITS_FILE \
|
||||
--kl-divergence
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
|||
QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}"
|
||||
TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}"
|
||||
OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}"
|
||||
BUILD_DIR="${5:-"$BUILD_DIR"}"
|
||||
QUANTIZED_MODEL=$CONVERTED_MODEL
|
||||
|
||||
# Final check if we have a model path
|
||||
|
|
@ -33,12 +34,16 @@ else
|
|||
exit 1
|
||||
fi
|
||||
|
||||
cmake --build ../../build --target llama-quantize -j8
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
cmake --build $BUILD_DIR --target llama-quantize -j8
|
||||
|
||||
echo $TOKEN_EMBD_TYPE
|
||||
echo $OUTPUT_TYPE
|
||||
|
||||
CMD_ARGS=("../../build/bin/llama-quantize")
|
||||
CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize")
|
||||
[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE")
|
||||
[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE")
|
||||
CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ set -e
|
|||
#
|
||||
# First try command line argument, then environment variable, then file
|
||||
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
|
||||
BUILD_DIR="${2:-"$BUILD_DIR"}"
|
||||
|
||||
# Final check if we have a model path
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
|
|
@ -13,10 +14,14 @@ if [ -z "$CONVERTED_MODEL" ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
BUILD_DIR="../../build"
|
||||
fi
|
||||
|
||||
echo $CONVERTED_MODEL
|
||||
|
||||
cmake --build ../../build --target llama-server
|
||||
cmake --build $BUILD_DIR --target llama-server
|
||||
|
||||
../../build/bin/llama-server -m $CONVERTED_MODEL \
|
||||
${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \
|
||||
--embedding \
|
||||
--pooling none
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -34,10 +34,8 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model_tgt = NULL;
|
||||
//llama_model * model_dft = NULL;
|
||||
|
||||
llama_context * ctx_tgt = NULL;
|
||||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
|
|
@ -48,26 +46,38 @@ int main(int argc, char ** argv) {
|
|||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.n_ctx = params.speculative.n_ctx;
|
||||
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
llama_model_ptr model_dft;
|
||||
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
}
|
||||
// TODO: simplify this logic
|
||||
{
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
auto params_dft = params;
|
||||
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
params_dft.n_parallel = 1;
|
||||
params_dft.n_ctx = params_spec.n_ctx;
|
||||
params_dft.n_batch = llama_n_ctx_seq(ctx_tgt);
|
||||
params_dft.devices = params_spec.devices;
|
||||
params_dft.model = params_spec.mparams_dft;
|
||||
params_dft.n_gpu_layers = params_spec.n_gpu_layers;
|
||||
|
||||
//model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
if (params_spec.cpuparams.n_threads > 0) {
|
||||
params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
}
|
||||
|
||||
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
|
||||
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
|
||||
params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
|
||||
auto mparams_dft = common_model_params_to_llama(params_dft);
|
||||
|
||||
model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft));
|
||||
if (model_dft == nullptr) {
|
||||
LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.speculative.model_dft = model_dft.get();
|
||||
params.speculative.cparams_dft = common_context_params_to_llama(params_dft);
|
||||
}
|
||||
|
||||
// Tokenize the prompt
|
||||
|
|
@ -92,12 +102,6 @@ int main(int argc, char ** argv) {
|
|||
LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
|
||||
}
|
||||
|
||||
// how many tokens to draft each time
|
||||
int n_draft = params.speculative.n_max;
|
||||
int n_draft_min = params.speculative.n_min;
|
||||
|
||||
float p_min = params.speculative.p_min;
|
||||
|
||||
int n_predict = 0;
|
||||
int n_drafted = 0;
|
||||
int n_accept = 0;
|
||||
|
|
@ -127,15 +131,11 @@ int main(int argc, char ** argv) {
|
|||
int n_past = inp.size() - 1;
|
||||
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
|
||||
params_spec.p_min = p_min;
|
||||
const auto & params_spec = params.speculative;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft);
|
||||
for (auto &pair : params.speculative.replacements) {
|
||||
common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str());
|
||||
}
|
||||
struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt);
|
||||
|
||||
common_speculative_begin(spec, prompt_tgt);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
|
|||
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
|
||||
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
|
||||
|
||||
|
|
@ -162,7 +162,7 @@ int main(int argc, char ** argv) {
|
|||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
// do not waste time on small drafts
|
||||
if (draft.size() < (size_t) n_draft_min) {
|
||||
if (draft.size() < (size_t) params_spec.n_min) {
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
|
|
@ -240,7 +240,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("n_draft = %d\n", n_draft);
|
||||
LOG_INF("n_draft = %d\n", params_spec.n_max);
|
||||
LOG_INF("n_predict = %d\n", n_predict);
|
||||
LOG_INF("n_drafted = %d\n", n_drafted);
|
||||
LOG_INF("n_accept = %d\n", n_accept);
|
||||
|
|
@ -249,8 +249,6 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("\n");
|
||||
LOG_INF("draft:\n\n");
|
||||
|
||||
llama_perf_context_print(ctx_dft);
|
||||
|
||||
LOG_INF("\n");
|
||||
LOG_INF("target:\n\n");
|
||||
common_perf_print(ctx_tgt, smpl);
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
common_init();
|
||||
|
||||
if (params.speculative.model.path.empty()) {
|
||||
if (params.speculative.mparams_dft.path.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
params.model = params.speculative.model;
|
||||
params.model = params.speculative.mparams_dft;
|
||||
params.n_gpu_layers = params.speculative.n_gpu_layers;
|
||||
if (params.speculative.cpuparams.n_threads > 0) {
|
||||
params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
|
||||
|
|
|
|||
|
|
@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU)
|
|||
option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF)
|
||||
option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF)
|
||||
option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
|
|
@ -320,6 +322,7 @@ set(GGML_PUBLIC_HEADERS
|
|||
include/ggml-opt.h
|
||||
include/ggml-metal.h
|
||||
include/ggml-rpc.h
|
||||
include/ggml-virtgpu.h
|
||||
include/ggml-sycl.h
|
||||
include/ggml-vulkan.h
|
||||
include/ggml-webgpu.h
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend"
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -630,10 +630,11 @@ extern "C" {
|
|||
|
||||
// this tensor...
|
||||
enum ggml_tensor_flag {
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
|
||||
GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
|
||||
GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
|
||||
GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
|
||||
};
|
||||
|
||||
enum ggml_tri_type {
|
||||
|
|
@ -2577,11 +2578,42 @@ extern "C" {
|
|||
struct ggml_tensor * grad,
|
||||
struct ggml_tensor * sgd_params); // alpha, weight decay
|
||||
|
||||
// build forward mutiple tensors and select one of them for computing
|
||||
// this is useful for creating graphs that have constant topology but compute different things based on the input
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18550
|
||||
//
|
||||
// automatic differentiation
|
||||
// nodes:
|
||||
// | - build forward into the graph but do not compute
|
||||
// c - build forward into the graph and compute
|
||||
//
|
||||
// | | ... c ... |
|
||||
// | | ... c ... |
|
||||
// | | ... c ... |
|
||||
// [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx)
|
||||
// c
|
||||
// c
|
||||
//
|
||||
// example:
|
||||
// struct ggml_tensor * curs[3];
|
||||
//
|
||||
// curs[0] = compute0(...);
|
||||
// curs[1] = compute1(...);
|
||||
// curs[2] = compute2(...);
|
||||
//
|
||||
// int idx = select_branch(some_input);
|
||||
//
|
||||
// struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
|
||||
//
|
||||
GGML_API struct ggml_tensor * ggml_build_forward_select(
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor ** tensors,
|
||||
int n_tensors,
|
||||
int idx);
|
||||
|
||||
GGML_API void ggml_build_forward_expand(
|
||||
struct ggml_cgraph * cgraph,
|
||||
struct ggml_tensor * tensor);
|
||||
|
||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_build_backward_expand(
|
||||
struct ggml_context * ctx, // context for gradient computation
|
||||
struct ggml_cgraph * cgraph,
|
||||
|
|
@ -2613,7 +2645,7 @@ extern "C" {
|
|||
GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
|
||||
|
||||
// dump the graph into a file using the dot format
|
||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
|
||||
GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
|
||||
|
||||
// TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
|
||||
typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
|
||||
|
|
|
|||
|
|
@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
|
|||
endif()
|
||||
|
||||
add_library(ggml
|
||||
ggml-backend-dl.cpp
|
||||
ggml-backend-reg.cpp)
|
||||
add_library(ggml::ggml ALIAS ggml)
|
||||
|
||||
|
|
@ -451,6 +452,7 @@ ggml_add_backend(HIP)
|
|||
ggml_add_backend(METAL)
|
||||
ggml_add_backend(MUSA)
|
||||
ggml_add_backend(RPC)
|
||||
ggml_add_backend(VirtGPU)
|
||||
ggml_add_backend(SYCL)
|
||||
ggml_add_backend(Vulkan)
|
||||
ggml_add_backend(WebGPU)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,48 @@
|
|||
#include "ggml-backend-dl.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
return handle;
|
||||
}
|
||||
|
||||
void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#pragma once
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
# ifndef NOMINMAX
|
||||
# define NOMINMAX
|
||||
# endif
|
||||
# include <windows.h>
|
||||
# include <winevt.h>
|
||||
#else
|
||||
# include <dlfcn.h>
|
||||
# include <unistd.h>
|
||||
#endif
|
||||
#include <filesystem>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
dl_handle * dl_load_library(const fs::path & path);
|
||||
void * dl_get_sym(dl_handle * handle, const char * name);
|
||||
const char * dl_error();
|
||||
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-backend-dl.h"
|
||||
#include "ggml-impl.h"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
|
@ -69,6 +70,10 @@
|
|||
#include "ggml-rpc.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_VIRTGPU_FRONTEND
|
||||
#include "ggml-virtgpu.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_CANN
|
||||
#include "ggml-cann.h"
|
||||
#endif
|
||||
|
|
@ -77,105 +82,23 @@
|
|||
#include "ggml-zendnn.h"
|
||||
#endif
|
||||
|
||||
// disable C++17 deprecation warning for std::codecvt_utf8
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic push
|
||||
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic push
|
||||
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
|
||||
#endif
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
static std::string path_str(const fs::path & path) {
|
||||
std::string u8path;
|
||||
try {
|
||||
#if defined(__cpp_lib_char8_t)
|
||||
// C++20 and later: u8string() returns std::u8string
|
||||
std::u8string u8str = path.u8string();
|
||||
u8path = std::string(reinterpret_cast<const char*>(u8str.c_str()));
|
||||
const std::u8string u8str = path.u8string();
|
||||
return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size());
|
||||
#else
|
||||
// C++17: u8string() returns std::string
|
||||
u8path = path.u8string();
|
||||
return path.u8string();
|
||||
#endif
|
||||
} catch (...) {
|
||||
return std::string();
|
||||
}
|
||||
return u8path;
|
||||
}
|
||||
|
||||
#if defined(__clang__)
|
||||
# pragma clang diagnostic pop
|
||||
#elif defined(__GNUC__)
|
||||
# pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#ifdef _WIN32
|
||||
|
||||
using dl_handle = std::remove_pointer_t<HMODULE>;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(HMODULE handle) {
|
||||
FreeLibrary(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static dl_handle * dl_load_library(const fs::path & path) {
|
||||
// suppress error dialogs for missing DLLs
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
HMODULE handle = LoadLibraryW(path.wstring().c_str());
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
|
||||
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
|
||||
|
||||
void * p = (void *) GetProcAddress(handle, name);
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
return "";
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
using dl_handle = void;
|
||||
|
||||
struct dl_handle_deleter {
|
||||
void operator()(void * handle) {
|
||||
dlclose(handle);
|
||||
}
|
||||
};
|
||||
|
||||
static void * dl_load_library(const fs::path & path) {
|
||||
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
|
||||
|
||||
return handle;
|
||||
}
|
||||
|
||||
static void * dl_get_sym(dl_handle * handle, const char * name) {
|
||||
return dlsym(handle, name);
|
||||
}
|
||||
|
||||
static const char * dl_error() {
|
||||
const char *rslt = dlerror();
|
||||
return rslt != nullptr ? rslt : "";
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
|
||||
|
||||
struct ggml_backend_reg_entry {
|
||||
ggml_backend_reg_t reg;
|
||||
dl_handle_ptr handle;
|
||||
|
|
@ -196,7 +119,12 @@ struct ggml_backend_registry {
|
|||
register_backend(ggml_backend_sycl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_VULKAN
|
||||
// Add runtime disable check
|
||||
if (getenv("GGML_DISABLE_VULKAN") == nullptr) {
|
||||
register_backend(ggml_backend_vk_reg());
|
||||
} else {
|
||||
GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n");
|
||||
}
|
||||
#endif
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
register_backend(ggml_backend_webgpu_reg());
|
||||
|
|
@ -204,6 +132,10 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_ZDNN
|
||||
register_backend(ggml_backend_zdnn_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_VIRTGPU_FRONTEND
|
||||
register_backend(ggml_backend_virtgpu_reg());
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
|
|
@ -620,6 +552,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_load_best("virtgpu", silent, dir_path);
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
|
|
|
|||
|
|
@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
|
|||
}
|
||||
if (sched->debug > 1) {
|
||||
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
|
||||
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
|
||||
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
|
||||
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
|
||||
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
|
||||
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
||||
struct ggml_tensor * src = node->src[j];
|
||||
if (src == NULL) {
|
||||
|
|
@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
|
|||
dst->view_offs = src->view_offs;
|
||||
}
|
||||
dst->op = src->op;
|
||||
dst->flags = src->flags;
|
||||
memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
|
||||
ggml_set_name(dst, src->name);
|
||||
|
||||
|
|
|
|||
|
|
@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
|
|||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
ggml_backend_blas_mul_mat(ctx, node);
|
||||
|
|
|
|||
|
|
@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
|
|||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool ok = ggml_cann_compute_forward(*cann_ctx, node);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
// Rename `_generic` functions if no native implementation is available.
|
||||
|
|
@ -38,9 +39,11 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -48,9 +51,11 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -70,12 +75,16 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
|
||||
|
|
@ -94,9 +103,11 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -104,9 +115,11 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -126,9 +139,11 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -136,9 +151,11 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -165,18 +182,22 @@
|
|||
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -202,9 +223,11 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -212,9 +235,11 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
@ -242,9 +267,11 @@
|
|||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
|
||||
|
|
@ -252,9 +279,11 @@
|
|||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -6,6 +6,9 @@
|
|||
#include "ggml-impl.h"
|
||||
#include "simd-mappings.h"
|
||||
|
||||
#define GGML_FA_TILE_Q 32
|
||||
#define GGML_FA_TILE_KV 16
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <utility>
|
||||
|
|
@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa
|
|||
return {ir0, ir1};
|
||||
}
|
||||
|
||||
struct ggml_fa_tile_config {
|
||||
static constexpr size_t Q = GGML_FA_TILE_Q;
|
||||
static constexpr size_t KV = GGML_FA_TILE_KV;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include "vec.h"
|
||||
#include "ops.h"
|
||||
#include "ggml.h"
|
||||
#include "common.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
|
|
@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan(
|
|||
} break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
{
|
||||
const int64_t ne10 = node->src[1]->ne[0]; // DK
|
||||
const int64_t ne20 = node->src[2]->ne[0]; // DV
|
||||
const int64_t DK = node->src[1]->ne[0];
|
||||
const int64_t DV = node->src[2]->ne[0];
|
||||
|
||||
cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
|
||||
// Tiled flash attention scratch (tile sizes defined in common.h)
|
||||
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
|
||||
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
{
|
||||
|
|
@ -2943,6 +2946,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
|||
continue;
|
||||
}
|
||||
|
||||
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_compute_forward(¶ms, node);
|
||||
|
||||
if (state->ith == 0 && cplan->abort_callback &&
|
||||
|
|
|
|||
|
|
@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX {
|
|||
} \
|
||||
} \
|
||||
|
||||
template<typename T>
|
||||
struct mma_instr;
|
||||
|
||||
template<>
|
||||
struct mma_instr<ggml_bf16_t> {
|
||||
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
||||
__builtin_mma_xvbf16ger2pp(acc, a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct mma_instr<ggml_fp16_t> {
|
||||
static inline void outer_product(acc_t *acc, vec_t a, vec_t b) {
|
||||
__builtin_mma_xvf16ger2pp(acc, a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
class tinyBLAS_BF16_PPC {
|
||||
class tinyBLAS_HP16_PPC {
|
||||
public:
|
||||
tinyBLAS_BF16_PPC(int64_t k,
|
||||
tinyBLAS_HP16_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const TB *B, int64_t ldb,
|
||||
TC *C, int64_t ldc,
|
||||
|
|
@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC {
|
|||
packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A);
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
|
|
@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC {
|
|||
packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A);
|
||||
packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
|
|
@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC {
|
|||
packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x < 4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC {
|
|||
packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B);
|
||||
for (int x = 0; x<2; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
|
|
@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC {
|
|||
packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A);
|
||||
packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B);
|
||||
for (int x = 0; x<4; x++) {
|
||||
__builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
|
||||
mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
|
||||
}
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
|
|
@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||
return tb.matmul(m, n);
|
||||
}
|
||||
#elif defined(__MMA__)
|
||||
if ((k % 8))
|
||||
return false;
|
||||
if(Btype == GGML_TYPE_BF16) {
|
||||
tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
||||
(const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
if (k % 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Btype == GGML_TYPE_BF16) {
|
||||
tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k,
|
||||
(const ggml_bf16_t *)A, lda,
|
||||
(const ggml_bf16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth };
|
||||
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
}
|
||||
#elif defined(__riscv_zvfbfwma)
|
||||
#if LMUL == 1
|
||||
|
|
@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||
#endif
|
||||
return tb.matmul(m, n);
|
||||
}
|
||||
#elif defined(__MMA__)
|
||||
if (k % 8) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Btype == GGML_TYPE_F16) {
|
||||
tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k,
|
||||
(const ggml_fp16_t *)A, lda,
|
||||
(const ggml_fp16_t *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
params->ith, params->nth };
|
||||
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
// online softmax / attention
|
||||
// loop over n_kv and n_head_kv
|
||||
// ref: https://arxiv.org/pdf/2112.05682.pdf
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
|
||||
if (mv == -INFINITY) {
|
||||
|
|
@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_tiled(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst,
|
||||
int ir0, int ir1) {
|
||||
const ggml_tensor * q = dst->src[0];
|
||||
const ggml_tensor * k = dst->src[1];
|
||||
const ggml_tensor * v = dst->src[2];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||
|
||||
const int64_t DK = nek0;
|
||||
const int64_t DV = nev0;
|
||||
const int64_t N = neq1;
|
||||
|
||||
GGML_ASSERT(ne0 == DV);
|
||||
GGML_ASSERT(ne2 == N);
|
||||
|
||||
// input tensor rows must be contiguous
|
||||
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||
|
||||
GGML_ASSERT(neq0 == DK);
|
||||
GGML_ASSERT(nek0 == DK);
|
||||
GGML_ASSERT(nev0 == DV);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(k->type == v->type);
|
||||
const ggml_type kv_type = k->type;
|
||||
|
||||
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
|
||||
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
|
||||
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
|
||||
const size_t kv_type_size = ggml_type_size(kv_type);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t rk2 = neq2/nek2;
|
||||
const int64_t rk3 = neq3/nek3;
|
||||
|
||||
const int64_t rv2 = neq2/nev2;
|
||||
const int64_t rv3 = neq3/nev3;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = neq2;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
int ith = params->ith;
|
||||
|
||||
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
|
||||
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
|
||||
|
||||
int ir = ir0;
|
||||
while (ir < ir1) {
|
||||
// q indices for the start of this tile
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
|
||||
// Number of valid rows in this tile:
|
||||
// - limited by tile size (Q_TILE_SZ)
|
||||
// - limited by chunk boundary (ir1 - ir)
|
||||
// - limited by head boundary (neq1 - iq1) to avoid crossing into next head
|
||||
const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
|
||||
GGML_ASSERT(tile_rows > 0);
|
||||
|
||||
const uint32_t h = iq2; // head index
|
||||
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
|
||||
|
||||
float S[Q_TILE_SZ];
|
||||
float M[Q_TILE_SZ];
|
||||
|
||||
for (int i = 0 ; i < Q_TILE_SZ; ++i) {
|
||||
S[i] = 0.;
|
||||
M[i] = -INFINITY;
|
||||
}
|
||||
|
||||
// Per-thread scratch layout:
|
||||
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
|
||||
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
|
||||
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
|
||||
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
|
||||
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
|
||||
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
|
||||
|
||||
void * Q_q = base;
|
||||
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
|
||||
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
|
||||
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
|
||||
|
||||
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
|
||||
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
|
||||
|
||||
// k indices
|
||||
const int ik3 = iq3 / rk3;
|
||||
const int ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const int iv3 = iq3 / rv3;
|
||||
const int iv2 = iq2 / rv2;
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
|
||||
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
|
||||
}
|
||||
// Zero-pad remaining rows
|
||||
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
|
||||
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
|
||||
}
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
|
||||
|
||||
// skip the tile entirely if all the masks are -inf
|
||||
if (mask) {
|
||||
bool can_skip = true;
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
|
||||
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
|
||||
can_skip = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (can_skip) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||
float s;
|
||||
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
|
||||
KQ[tq * KV_TILE_SZ + tk] = s * scale;
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
|
||||
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
|
||||
}
|
||||
|
||||
if (mask) {
|
||||
ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
|
||||
}
|
||||
|
||||
bool skip[Q_TILE_SZ] = {};
|
||||
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
float * kq_row = KQ + tq * KV_TILE_SZ;
|
||||
|
||||
float tile_max;
|
||||
ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
|
||||
|
||||
if (tile_max == -INFINITY) {
|
||||
skip[tq] = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
const float Mold = M[tq];
|
||||
const float Mnew = fmaxf(Mold, tile_max);
|
||||
|
||||
if (Mnew > Mold) {
|
||||
const float ms = expf(Mold - Mnew);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
S[tq] *= ms;
|
||||
}
|
||||
M[tq] = Mnew;
|
||||
|
||||
|
||||
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
|
||||
}
|
||||
|
||||
// Convert V tile to F32 first (if F16), then do MAD
|
||||
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
|
||||
// TODO: on ARM, native f16 should be faster
|
||||
if (kv_type == GGML_TYPE_F16) {
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
|
||||
}
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
|
||||
if (skip[tq]) continue;
|
||||
float * vkq_row = VKQ32 + tq * DV;
|
||||
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
|
||||
const float p = KQ[tq * KV_TILE_SZ + tk];
|
||||
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sinks (apply only to valid rows in the tile)
|
||||
if (sinks) {
|
||||
const float s = ((float *)((char *) sinks->data))[h];
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
|
||||
if (s > M[tq]) {
|
||||
ms = expf(M[tq] - s);
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
|
||||
} else {
|
||||
vs = expf(s - M[tq]);
|
||||
}
|
||||
|
||||
S[tq] = S[tq] * ms + vs;
|
||||
}
|
||||
}
|
||||
|
||||
for (int tq = 0; tq < tile_rows; tq++) {
|
||||
// V /= S
|
||||
const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
|
||||
ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
|
||||
|
||||
// dst indices
|
||||
const int i1 = iq1 + tq;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
// permute(0, 2, 1, 3)
|
||||
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
|
||||
}
|
||||
|
||||
ir += tile_rows;
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
|
|
@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
// The number of elements in each chunk
|
||||
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||
|
||||
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
|
||||
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
|
||||
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
|
||||
const bool use_tiled = (q->type == GGML_TYPE_F32 &&
|
||||
kv_is_f32_or_f16 &&
|
||||
k->type == v->type &&
|
||||
nek1 % KV_TILE_SZ == 0 &&
|
||||
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
|
||||
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
|
|
@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||
const int64_t ir0 = dr * current_chunk;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
if (use_tiled) {
|
||||
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
|
||||
} else {
|
||||
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||
}
|
||||
|
||||
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
|
|
@ -616,6 +609,191 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0f;
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
|
||||
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
// qh indexing with 8-byte interleaving (like q5_K)
|
||||
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_byte_l / 8;
|
||||
const int qh_pos_l = qh_byte_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_byte_h / 8;
|
||||
const int qh_pos_h = qh_byte_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t a_l = a_ptr[l].qs[base_l + i];
|
||||
const int8_t a_h = a_ptr[l].qs[base_h + i];
|
||||
|
||||
sumi_l += q_l * a_l;
|
||||
sumi_h += q_h * a_h;
|
||||
}
|
||||
|
||||
sumf[j] +=
|
||||
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -1046,15 +1224,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
|
|
@ -1212,6 +1382,213 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
|
||||
float sumf[4][8];
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int k = 0; k < 16; k++) {
|
||||
// k = 0.. 7 weights 0-63 low, 64-127 high
|
||||
// k = 8..15 weights 128-191 low, 192-255 high
|
||||
const int base_l = (k / 8) * 128 + (k % 8) * 8;
|
||||
const int base_h = base_l + 64;
|
||||
|
||||
const int scale_idx_l = base_l / 16;
|
||||
const int scale_idx_h = base_h / 16;
|
||||
|
||||
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
|
||||
const int qh_shift_l = ((base_l % 128) / 32) * 2;
|
||||
const int qh_shift_h = ((base_h % 128) / 32) * 2;
|
||||
|
||||
// qh_half: offset to the correct 32-byte half (0 or 32)
|
||||
const int qh_half_l = (base_l / 128) * 32;
|
||||
const int qh_half_h = (base_h / 128) * 32;
|
||||
|
||||
// Activation base indices for q8_Kx4 interleaved format
|
||||
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
|
||||
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
// Interleaved scales
|
||||
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
|
||||
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
|
||||
|
||||
int sumi_l = 0;
|
||||
int sumi_h = 0;
|
||||
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int ql_pos = k * 64 + j * 8 + i;
|
||||
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
|
||||
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
|
||||
|
||||
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
|
||||
const int qh_chunk_l = qh_idx_l / 8;
|
||||
const int qh_pos_l = qh_idx_l % 8;
|
||||
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
|
||||
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
|
||||
|
||||
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
|
||||
const int qh_chunk_h = qh_idx_h / 8;
|
||||
const int qh_pos_h = qh_idx_h % 8;
|
||||
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
|
||||
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
|
||||
|
||||
const int q_l = ((hi_2_l << 4) | l_4) - 32;
|
||||
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
|
||||
|
||||
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
|
||||
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
|
||||
|
||||
sumi_l += q_l * q8_l;
|
||||
sumi_h += q_h * q8_h;
|
||||
}
|
||||
|
||||
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
|
||||
a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
|
|
@ -1612,8 +1989,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
|
|||
// Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure
|
||||
// For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures
|
||||
|
||||
for(int i = 0; i < 128; i++){
|
||||
|
||||
for (int i = 0; i < 128; i++) {
|
||||
// Index for selecting which q2k super block
|
||||
int src1 = (i % 16) / 2;
|
||||
// Index for selecting scale
|
||||
|
|
@ -1622,7 +1998,141 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
|
|||
out.scales[i] = in[src1].scales[src2];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
|
||||
block_q5_Kx8 out;
|
||||
//Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||
}
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||
}
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
for (int i = 0; i < end / 4; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
// The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
|
||||
// Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
|
||||
// The output Q5_Kx8 structure has 96 bytes
|
||||
// Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
|
||||
// For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
|
||||
uint8_t s[8], m[8];
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = in[j].scales[i] & 63;
|
||||
m[j] = in[j].scales[i + 4] & 63;
|
||||
}
|
||||
|
||||
out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
for (int j = 0; j < 8; j++) {
|
||||
s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
|
||||
m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
|
||||
}
|
||||
|
||||
out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
|
||||
out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
|
||||
out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
|
||||
out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
|
||||
out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
|
||||
out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
|
||||
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
|
||||
out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
|
||||
out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
|
||||
out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
|
||||
out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
|
||||
out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) {
|
||||
block_q6_Kx8 out;
|
||||
constexpr int n_blocks = 8; // Kx8
|
||||
for (int i = 0; i < n_blocks; i++) {
|
||||
out.d[i] = in[i].d;
|
||||
}
|
||||
|
||||
const int end_ls = QK_K * 4 / blck_size_interleave;
|
||||
// Interleave Q6_K quants by taking 8 bytes at a time
|
||||
for (int i = 0; i < end_ls; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
int src_offset = (i / n_blocks) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_ls;
|
||||
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// Interleave high bits using same 8-byte pattern as low bits
|
||||
const int end_hs = end_ls / 2;
|
||||
for (int i = 0; i < end_hs; ++i) {
|
||||
int src_id = i % n_blocks;
|
||||
int src_offset = (i / n_blocks) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elem_hs;
|
||||
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
|
||||
}
|
||||
|
||||
// The below logic is designed so as to unpack and rearrange scales in Q6_K
|
||||
// The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants
|
||||
// Q6_K structure has an 8-bit scale per 16 elements -> 16 scales
|
||||
// scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block)
|
||||
constexpr int n_scales = QK_K / 16;
|
||||
|
||||
for (int i = 0; i < n_blocks; i++) {
|
||||
for (int j = 0; j < n_scales; j++) {
|
||||
out.scales[j * n_blocks + i] = in[i].scales[j];
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
|
|
@ -1706,7 +2216,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
|
|||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++ ) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
|
||||
|
|
@ -1718,6 +2228,67 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
|
|||
GGML_UNUSED(data_size);
|
||||
}
|
||||
|
||||
static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
int interleave_block,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
const block_q5_K * src = (const block_q5_K *) data;
|
||||
block_q5_K dst_tmp[8];
|
||||
int nrow = ggml_nrows(t);
|
||||
int nblocks = t->ne[0] / QK_K;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
|
||||
const block_q6_K * src = (const block_q6_K *) data;
|
||||
block_q6_K dst_tmp[8];
|
||||
int nrow = ggml_nrows(t);
|
||||
int nblocks = t->ne[0] / QK_K;
|
||||
|
||||
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K));
|
||||
|
||||
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (int b = 0; b < nrow; b += nrows_interleaved) {
|
||||
for (int64_t x = 0; x < nblocks; x++) {
|
||||
for (int i = 0; i < nrows_interleaved; i++) {
|
||||
dst_tmp[i] = src[x + i * nblocks];
|
||||
}
|
||||
*dst++ = make_block_q6_Kx8(dst_tmp, interleave_block);
|
||||
}
|
||||
src += nrows_interleaved * nblocks;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
|
|
@ -1936,6 +2507,14 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
|||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
|
@ -1973,6 +2552,17 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
|||
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <>
|
||||
void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n,
|
||||
float * s,
|
||||
size_t bs,
|
||||
const void * vx,
|
||||
const void * vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
|
@ -1981,8 +2571,12 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
|||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
|
|
@ -2013,20 +2607,35 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
|
|||
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <>
|
||||
void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n,
|
||||
float * s,
|
||||
size_t bs,
|
||||
const void * vx,
|
||||
const void * vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
|
|
@ -2393,20 +3002,19 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
for (int ir1 = 0; ir1 < nr1; ir1++) {
|
||||
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
|
||||
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
const int id = row_mapping.i1; // selected expert index
|
||||
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
const int64_t i12 = row_mapping.i2; // row index in src1
|
||||
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
const int64_t i1 = id; // selected expert index
|
||||
const int64_t i2 = i12; // row
|
||||
|
||||
const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
|
||||
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
|
||||
(float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
||||
src0_cur + src0_cur_start * nb01,
|
||||
src1_col, 1, src0_cur_end - src0_cur_start);
|
||||
gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(
|
||||
ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
|
||||
src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
|
||||
}
|
||||
}
|
||||
#undef MMID_MATRIX_ROW
|
||||
|
|
@ -2422,7 +3030,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
} // namespace ggml::cpu::repack
|
||||
|
||||
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
|
||||
|
||||
// instance for Q4
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
|
||||
|
|
@ -2432,6 +3039,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
|
||||
|
||||
// instance for Q2
|
||||
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
|
||||
|
||||
|
|
@ -2482,6 +3095,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
return &q2_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q5_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q6_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q6_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_IQ4_NL) {
|
||||
if (ggml_cpu_has_avx2()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ struct block_q4_Kx8 {
|
|||
};
|
||||
|
||||
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
|
||||
|
||||
struct block_q2_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
|
|
@ -52,6 +53,28 @@ struct block_q2_Kx8 {
|
|||
};
|
||||
|
||||
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
|
||||
|
||||
struct block_q5_Kx8 {
|
||||
ggml_half d[8]; // super-block scale for quantized scales
|
||||
ggml_half dmin[8]; // super-block scale for quantized mins
|
||||
uint8_t scales[96]; // scales and mins, quantized with 6 bits
|
||||
uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
|
||||
uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
|
||||
"wrong q5_K block size/padding");
|
||||
|
||||
struct block_q6_Kx8 {
|
||||
ggml_half d[8];
|
||||
int8_t scales[QK_K / 16 * 8];
|
||||
uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2)
|
||||
uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4)
|
||||
};
|
||||
|
||||
static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8,
|
||||
"wrong q6_K block size/padding");
|
||||
|
||||
struct block_q8_Kx4 {
|
||||
float d[4]; // delta
|
||||
int8_t qs[QK_K * 4]; // quants
|
||||
|
|
@ -85,17 +108,21 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -111,17 +138,21 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
|
||||
# define STRIDED_ITERATOR_AVAILABLE
|
||||
# endif
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
|
|
@ -14,12 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef STRIDED_ITERATOR_AVAILABLE
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
#endif // STRIDED_ITERATOR_AVAILABLE
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
|
|
@ -31,19 +36,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
#ifdef STRIDED_ITERATOR_AVAILABLE
|
||||
auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
|
||||
#else
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
int * offset_iterator = offsets_alloc.get();
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
|
||||
#endif
|
||||
CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
|
@ -57,7 +65,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, stream);
|
||||
offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
|
|
@ -66,7 +74,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -80,7 +89,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
|
||||
ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
|
||||
}
|
||||
} else {
|
||||
if (nrows == 1) {
|
||||
|
|
@ -89,8 +98,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
|||
ncols, 0, sizeof(float) * 8, stream);
|
||||
} else {
|
||||
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
stream);
|
||||
temp_indices, dst, ncols * nrows, nrows, offset_iterator,
|
||||
offset_iterator + 1, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@
|
|||
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
|
||||
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
|
||||
#define GGML_CUDA_CC_BLACKWELL 1200
|
||||
#define GGML_CUDA_CC_DGX_SPARK 1210
|
||||
#define GGML_CUDA_CC_RUBIN 1300
|
||||
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
|
||||
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
|
||||
|
|
@ -1121,14 +1122,18 @@ struct ggml_tensor_extra_gpu {
|
|||
#endif
|
||||
|
||||
struct ggml_cuda_graph_node_properties {
|
||||
void * node_address;
|
||||
void * node_data;
|
||||
ggml_op node_op;
|
||||
enum ggml_type node_type;
|
||||
int32_t flags;
|
||||
int64_t ne[GGML_MAX_DIMS];
|
||||
size_t nb[GGML_MAX_DIMS];
|
||||
void * src_address[GGML_MAX_SRC];
|
||||
void * src_data[GGML_MAX_SRC];
|
||||
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
};
|
||||
|
||||
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
|
||||
|
||||
struct ggml_cuda_graph {
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
~ggml_cuda_graph() {
|
||||
|
|
@ -1148,6 +1153,12 @@ struct ggml_cuda_graph {
|
|||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_cuda_graph_node_properties> props;
|
||||
|
||||
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
|
||||
// they properties also have to match in order to be able to safely reuse a CUDA graph
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
|
||||
std::vector<ggml_cuda_graph_node_properties> extra;
|
||||
|
||||
void record_update(bool use_graph, bool update_required) {
|
||||
if (use_graph && update_required) {
|
||||
number_consecutive_updates++;
|
||||
|
|
@ -1326,10 +1337,44 @@ struct ggml_backend_cuda_context {
|
|||
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
|
||||
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
|
||||
|
||||
std::unique_ptr<ggml_cuda_graph> cuda_graph;
|
||||
|
||||
int curr_stream_no = 0;
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
// Map from first_node_ptr to cuda_graph - allows multiple graphs per context
|
||||
// when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
|
||||
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
|
||||
|
||||
ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
|
||||
auto it = cuda_graphs.find(first_node_ptr);
|
||||
if (it == cuda_graphs.end()) {
|
||||
cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
|
||||
return cuda_graphs[first_node_ptr].get();
|
||||
}
|
||||
return it->second.get();
|
||||
}
|
||||
|
||||
// Check if any CUDA graph is enabled for this context (used by kernels that need to know
|
||||
// if graphs are in use without having access to the specific graph key)
|
||||
bool any_cuda_graph_enabled() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->is_enabled()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if any CUDA graph has an instance for this context
|
||||
bool any_cuda_graph_has_instance() const {
|
||||
for (const auto & [key, graph] : cuda_graphs) {
|
||||
if (graph && graph->instance != nullptr) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
#endif // USE_CUDA_GRAPH
|
||||
|
||||
explicit ggml_backend_cuda_context(int device) :
|
||||
device(device),
|
||||
name(GGML_CUDA_NAME + std::to_string(device)) {
|
||||
|
|
|
|||
|
|
@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max(
|
|||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
|
||||
const int nbatch_fa) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
|
||||
const int ne11, const int ne12, const int nbatch_fa) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
|
@ -641,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
|
||||
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
|
||||
|
||||
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
|
|
@ -654,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
return;
|
||||
}
|
||||
|
||||
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
|
||||
const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
|
||||
const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
|
||||
const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
|
||||
|
||||
if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
|
|
@ -681,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
|
|
@ -778,13 +785,11 @@ void launch_fattn(
|
|||
) {
|
||||
constexpr int ncols = ncols1 * ncols2;
|
||||
|
||||
const bool is_mla = DV == 512; // TODO better parameterization
|
||||
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(V || is_mla);
|
||||
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
|
@ -794,9 +799,9 @@ void launch_fattn(
|
|||
GGML_ASSERT(Q->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(KQV->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT( Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT( K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
|
||||
GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
|
||||
GGML_ASSERT(K->nb[0] == ggml_element_size(K));
|
||||
GGML_ASSERT(V->nb[0] == ggml_element_size(V));
|
||||
|
||||
GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
|
||||
|
||||
|
|
@ -817,10 +822,10 @@ void launch_fattn(
|
|||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
const char * V_data = V ? (const char *) V->data : nullptr;
|
||||
size_t nb21 = V ? V->nb[1] : nb11;
|
||||
size_t nb22 = V ? V->nb[2] : nb12;
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
|
|
@ -849,36 +854,45 @@ void launch_fattn(
|
|||
K_data = (char *) K_f16.ptr;
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
if (need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
if (V_is_K_view) {
|
||||
V_data = K_data;
|
||||
nb21 = nb11;
|
||||
nb22 = nb12;
|
||||
nb23 = nb13;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
const size_t ts = ggml_type_size(V->type);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
if (ggml_is_contiguously_allocated(V)) {
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
nb21 = nb21*bs*sizeof(half)/ts;
|
||||
nb22 = nb22*bs*sizeof(half)/ts;
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
} else {
|
||||
GGML_ASSERT(V->nb[0] == ts);
|
||||
to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
|
||||
const int64_t s01 = nb21 / ts;
|
||||
const int64_t s02 = nb22 / ts;
|
||||
const int64_t s03 = nb23 / ts;
|
||||
to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
|
||||
|
||||
nb21 = V->ne[0] * sizeof(half);
|
||||
nb22 = V->ne[1] * nb21;
|
||||
nb23 = V->ne[2] * nb22;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
}
|
||||
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
|
||||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
|
|
@ -953,7 +967,7 @@ void launch_fattn(
|
|||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
|
|
@ -1007,7 +1021,7 @@ void launch_fattn(
|
|||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue