Compare commits

149 commits: c5e62b465a ... ac8f5dcc36
| SHA1 |
|---|
| ac8f5dcc36 |
| f1f793ad06 |
| af5c13841f |
| 277ff5fff7 |
| 384c0076bc |
| 1f34806c44 |
| 887535c33f |
| d3416a4aa9 |
| 43a4ee4a2c |
| f851fa5ab0 |
| f1ac84119c |
| b069b10ab4 |
| 0c58ba3365 |
| 57ace0d612 |
| 39b27f0da0 |
| f49e917876 |
| 7c7d6ce5c7 |
| 5208e2d5ba |
| 7992aa7c8e |
| f1bb125740 |
| e45f28876e |
| 2bd682bb1f |
| 8744a9f7fa |
| 75e5a9bd01 |
| 805e9ac6a8 |
| 2d39f87bc6 |
| 5ce738cfa7 |
| 3c6a80ffa9 |
| 13820ad653 |
| 05d9a9132a |
| 0c571feee1 |
| 1c139e33df |
| e760cd49bd |
| 73444564e6 |
| ba754ce4f3 |
| 5e491258f9 |
| 5fbdefdb9d |
| 9bb5eb30e5 |
| 3e691046dc |
| 775e48abb2 |
| f54cd74ed0 |
| f2187bbfa2 |
| febee580c8 |
| bccd869968 |
| 721fa41076 |
| 3591e83db9 |
| fa7dd684bf |
| e489dd2773 |
| e10b495dd2 |
| dbeb6ced46 |
| 378bb8368e |
| 11bd9806bf |
| e4fbece606 |
| ecbbdb6608 |
| b4530b4f8b |
| 0cb1ff419a |
| b015e4b7dc |
| 7d99222a61 |
| 63c53fe1f1 |
| 8bfb7ed2f2 |
| 0939511846 |
| 9f498d29f1 |
| ea438d8b0e |
| c33e4301dc |
| fac6f0adc3 |
| a660d4d45d |
| 1fdcb05dc8 |
| 496c3599c6 |
| 5ed2c1b787 |
| 8e0e944b70 |
| a2db92f41c |
| 6106e9068b |
| a3fb36fb71 |
| a1fb3c1509 |
| 9cbc099493 |
| 64ead3fd4f |
| 414bb8d9ed |
| 8809af79a8 |
| 949eca4cba |
| 76885c7697 |
| df88b2c917 |
| 4e9ebe92e0 |
| ba70ad8e59 |
| 28b7094750 |
| 311213d209 |
| 68ccd2a899 |
| 09e3a5f07d |
| d9a48580fc |
| 688de6d7d8 |
| 6f44f47113 |
| 275c08d25d |
| 00a49c2fc1 |
| 8572313000 |
| 27881fbe7b |
| fa9e415c9b |
| f95664c76c |
| 417cfc3cc6 |
| c1f67c19e0 |
| 2b5351a898 |
| c141ce3533 |
| 1f3d5eb8e9 |
| 70132278cb |
| a3b4d8d31e |
| 55859a86aa |
| 2dfbbee73f |
| 1e568252b5 |
| 4b1920e9e7 |
| 75dde410a8 |
| 3ea524e9c4 |
| 6d12288037 |
| a3784e17ad |
| cc327f5224 |
| 30990788e8 |
| c68fe36ae2 |
| 475f9879c5 |
| 396f55831c |
| 610e41ae2d |
| c45df12ee7 |
| 980ddc1e87 |
| 24b553204b |
| 6c90c20cb1 |
| be25be8ed3 |
| 80a996cfc0 |
| 2715341c1d |
| 66f6d16265 |
| 215ebf6526 |
| 1b69ed44c6 |
| f931ad883f |
| f0a480cc22 |
| 15484c9bd6 |
| 6a1f8b4d57 |
| ac77b8d0e0 |
| 3f99818925 |
| b70cca2ea3 |
| 3e2f722d11 |
| 2237722056 |
| 16b0f0ae3c |
| 0ca43582e8 |
| c6255442bb |
| 53a2ccbe12 |
| 2ec76aa8f3 |
| 735886b099 |
| 83a3b7d6a9 |
| 4b0f9d571f |
| 5ffe97be9c |
| 6d84cbb5ab |
| 3877608dc0 |
| 4d772873b9 |
| 8a589317b6 |
@@ -1,97 +0,0 @@
-ARG UBUNTU_VERSION=24.04
-# This needs to generally match the container host's environment.
-ARG CUDA_VERSION=13.1.1
-# Target the CUDA build image
-ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
-
-ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
-
-FROM ${BASE_CUDA_DEV_CONTAINER} AS build
-
-# CUDA architecture to build for (defaults to all supported archs)
-ARG CUDA_DOCKER_ARCH=default
-
-RUN apt-get update && \
-    apt-get install -y gcc-14 g++-14 build-essential cmake python3 python3-pip git libssl-dev libgomp1
-
-ENV CC=gcc-14 CXX=g++-14 CUDAHOSTCXX=g++-14
-
-WORKDIR /app
-
-COPY . .
-
-RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \
-        export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \
-    fi && \
-    cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \
-    cmake --build build --config Release -j$(nproc)
-
-RUN mkdir -p /app/lib && \
-    find build -name "*.so*" -exec cp -P {} /app/lib \;
-
-RUN mkdir -p /app/full \
-    && cp build/bin/* /app/full \
-    && cp *.py /app/full \
-    && cp -r gguf-py /app/full \
-    && cp -r requirements /app/full \
-    && cp requirements.txt /app/full \
-    && cp .devops/tools.sh /app/full/tools.sh
-
-## Base image
-FROM ${BASE_CUDA_RUN_CONTAINER} AS base
-
-RUN apt-get update \
-    && apt-get install -y libgomp1 curl \
-    && apt autoremove -y \
-    && apt clean -y \
-    && rm -rf /tmp/* /var/tmp/* \
-    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
-    && find /var/cache -type f -delete
-
-COPY --from=build /app/lib/ /app
-
-### Full
-FROM base AS full
-
-COPY --from=build /app/full /app
-
-WORKDIR /app
-
-RUN apt-get update \
-    && apt-get install -y \
-    git \
-    python3 \
-    python3-pip \
-    python3-wheel \
-    && pip install --break-system-packages --upgrade setuptools \
-    && pip install --break-system-packages -r requirements.txt \
-    && apt autoremove -y \
-    && apt clean -y \
-    && rm -rf /tmp/* /var/tmp/* \
-    && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \
-    && find /var/cache -type f -delete
-
-
-ENTRYPOINT ["/app/tools.sh"]
-
-### Light, CLI only
-FROM base AS light
-
-COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
-
-WORKDIR /app
-
-ENTRYPOINT [ "/app/llama-cli" ]
-
-### Server, Server only
-FROM base AS server
-
-ENV LLAMA_ARG_HOST=0.0.0.0
-
-COPY --from=build /app/full/llama-server /app
-
-WORKDIR /app
-
-HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
-
-ENTRYPOINT [ "/app/llama-server" ]

@@ -16,7 +16,7 @@
     rocmPackages,
     vulkan-headers,
     vulkan-loader,
-    curl,
+    openssl,
     shaderc,
     useBlas ?
       builtins.all (x: !x) [

@@ -160,7 +160,8 @@ effectiveStdenv.mkDerivation (finalAttrs: {
       ++ optionals useMpi [ mpi ]
       ++ optionals useRocm rocmBuildInputs
       ++ optionals useBlas [ blas ]
-      ++ optionals useVulkan vulkanBuildInputs;
+      ++ optionals useVulkan vulkanBuildInputs
+      ++ [ openssl ];

     cmakeFlags =
       [

@@ -1,8 +1,8 @@
 ARG UBUNTU_VERSION=24.04

 # This needs to generally match the container host's environment.
-ARG ROCM_VERSION=7.2
-ARG AMDGPU_VERSION=7.2
+ARG ROCM_VERSION=7.2.1
+ARG AMDGPU_VERSION=7.2.1

 # Target the ROCm build image
 ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete

@@ -12,11 +12,11 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build

 # Unless otherwise specified, we make a fat build.
 # This is mostly tied to rocBLAS supported archs.
-# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html
-# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html
+# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.1/reference/system-requirements.html
+# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html

-ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201'
+ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201'

 # Set ROCm architectures
 ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}

@@ -27,6 +27,11 @@ IBM zDNN:
     - any-glob-to-any-file:
       - ggml/include/ggml-zdnn.h
       - ggml/src/ggml-zdnn/**
+AMD ZenDNN:
+  - changed-files:
+    - any-glob-to-any-file:
+      - ggml/include/ggml-zendnn.h
+      - ggml/src/ggml-zendnn/**
 documentation:
   - changed-files:
     - any-glob-to-any-file:

@@ -472,6 +472,7 @@ jobs:
          cmake -B build -S . \
            -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
+           -DGGML_HIP_ROCWMMA_FATTN=ON \
            -DGPU_TARGETS="gfx1030" \
            -DGGML_HIP=ON
          cmake --build build --config Release -j $(nproc)

@@ -941,7 +942,7 @@ jobs:
       - name: Grab rocWMMA package
         id: grab_rocwmma
         run: |
-          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
+          curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
          7z x rocwmma.deb
          7z x data.tar

@@ -984,12 +985,13 @@ jobs:
          cmake -G "Unix Makefiles" -B build -S . `
            -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
            -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-           -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" `
+           -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/" `
            -DCMAKE_BUILD_TYPE=Release `
+           -DLLAMA_BUILD_BORINGSSL=ON `
            -DROCM_DIR="${env:HIP_PATH}" `
            -DGGML_HIP=ON `
            -DGGML_HIP_ROCWMMA_FATTN=ON `
            -DGPU_TARGETS="gfx1100" `
            -DGGML_RPC=ON
          cmake --build build -j ${env:NUMBER_OF_PROCESSORS}

@@ -73,10 +73,10 @@ jobs:
          { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },
          { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" },
          { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" },
-         { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
-         { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
-         { "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
-         { "tag": "cuda13", "dockerfile": ".devops/cuda-new.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
+         { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.9.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
+         { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.9.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
+         { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
+         { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.1.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" },
          { "tag": "musa", "dockerfile": ".devops/musa.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
          { "tag": "intel", "dockerfile": ".devops/intel.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" },
          { "tag": "vulkan", "dockerfile": ".devops/vulkan.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" },

@@ -35,7 +35,7 @@ env:
 jobs:
   ubuntu-22-hip-quality-check:
     runs-on: ubuntu-22.04
-    container: rocm/dev-ubuntu-22.04:7.2
+    container: rocm/dev-ubuntu-22.04:7.2.1
     steps:
       - name: Clone
        id: checkout

@@ -59,7 +59,7 @@ jobs:
        run: |
          cmake -B build -S . \
            -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-           -DGPU_TARGETS=gfx908 \
+           -DGPU_TARGETS=gfx942 \
            -DGGML_HIP=ON \
            -DGGML_HIP_EXPORT_METRICS=Off \
            -DCMAKE_HIP_FLAGS="-Werror -Wno-tautological-compare" \

@@ -639,8 +639,8 @@ jobs:
     strategy:
       matrix:
         include:
-          - ROCM_VERSION: "7.2"
-            gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201"
+          - ROCM_VERSION: "7.2.1"
+            gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1150;gfx1200;gfx1201"
            build: 'x64'

    steps:

@@ -662,7 +662,7 @@ jobs:
          sudo apt install -y build-essential git cmake wget

      - name: Setup Legacy ROCm
-       if: matrix.ROCM_VERSION == '7.2'
+       if: matrix.ROCM_VERSION == '7.2.1'
        id: legacy_env
        run: |
          sudo mkdir --parents --mode=0755 /etc/apt/keyrings

@@ -683,7 +683,7 @@ jobs:
          sudo apt-get install -y libssl-dev rocm-hip-sdk

      - name: Setup TheRock
-       if: matrix.ROCM_VERSION != '7.2'
+       if: matrix.ROCM_VERSION != '7.2.1'
        id: therock_env
        run: |
          wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz

@@ -699,7 +699,6 @@ jobs:
        run: |
          cmake -B build -S . \
            -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \
-           -DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \
            -DCMAKE_BUILD_TYPE=Release \
            -DGGML_BACKEND_DL=ON \
            -DGGML_NATIVE=OFF \

@@ -717,17 +716,20 @@ jobs:
        id: tag
        uses: ./.github/actions/get-tag-name

+     - name: Get ROCm short version
+       run: echo "ROCM_VERSION_SHORT=$(echo '${{ matrix.ROCM_VERSION }}' | cut -d '.' -f 1,2)" >> $GITHUB_ENV
+
      - name: Pack artifacts
        id: pack_artifacts
        run: |
          cp LICENSE ./build/bin/
-         tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .
+         tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin .

      - name: Upload artifacts
        uses: actions/upload-artifact@v6
        with:
-         path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
-         name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz
+         path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz
+         name: llama-bin-ubuntu-rocm-${{ env.ROCM_VERSION_SHORT }}-${{ matrix.build }}.tar.gz

  windows-hip:
    runs-on: windows-2022

@@ -749,7 +751,7 @@ jobs:
      - name: Grab rocWMMA package
        id: grab_rocwmma
        run: |
-         curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb"
+         curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70201-81~24.04_amd64.deb"
         7z x rocwmma.deb
         7z x data.tar

@@ -806,7 +808,7 @@ jobs:
         cmake -G "Unix Makefiles" -B build -S . `
           -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
           -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
-          -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
+          -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
           -DCMAKE_BUILD_TYPE=Release `
           -DGGML_BACKEND_DL=ON `
           -DGGML_NATIVE=OFF `

ci/run.sh (32 changes)
@@ -221,7 +221,7 @@ function gg_run_ctest_debug {

    set -e

-   # Check cmake and ctest are installed
+   # Check required binaries are installed
    gg_check_build_requirements

    (cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Debug ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log

@@ -252,7 +252,7 @@ function gg_run_ctest_release {

    set -e

-   # Check cmake and ctest are installed
+   # Check required binaries are installed
    gg_check_build_requirements

    (cmake -G "${CMAKE_GENERATOR}" -DCMAKE_BUILD_TYPE=Release ${CMAKE_EXTRA} .. ) 2>&1 | tee -a $OUT/${ci}-cmake.log

@@ -627,10 +627,38 @@ function gg_sum_rerank_tiny {
 }

 function gg_check_build_requirements {
+    if ! command -v git &> /dev/null; then
+        gg_printf 'git not found, please install'
+    fi
+
+    if ! command -v git-lfs &> /dev/null; then
+        gg_printf 'git-lfs not found, please install'
+    fi
+
+    if ! command -v wget &> /dev/null; then
+        gg_printf 'wget not found, please install'
+    fi
+
+    if ! command -v python3 &> /dev/null; then
+        gg_printf 'python3 not found, please install'
+    fi
+
+    if ! command -v pip3 &> /dev/null; then
+        gg_printf 'pip3 not found, please install'
+    fi
+
+    if ! python3 -m ensurepip --help &> /dev/null; then
+        gg_printf 'ensurepip not found, please install python3-venv package'
+    fi
+
     if ! command -v cmake &> /dev/null; then
         gg_printf 'cmake not found, please install'
     fi

+    if ! command -v ccache &> /dev/null; then
+        gg_printf 'ccache not found, please consider installing for faster builds'
+    fi
+
     if ! command -v ctest &> /dev/null; then
         gg_printf 'ctest not found, please install'
     fi

@@ -6,12 +6,111 @@
 #include "json-schema-to-grammar.h"
 #include "log.h"
 #include "nlohmann/json.hpp"
+#include "peg-parser.h"

 #include <algorithm>
+#include <stdexcept>
 #include <string>

 using json = nlohmann::ordered_json;

+namespace {
+
+// Gemma4-specific PEG builder extending the standard chat builder.
+// Adds value type parsers that use <|"|> as string delimiters
+// instead of JSON's double quotes, and disables json-to-schema
+// conversion for these types.
+class common_peg_gemma4_builder {
+    common_chat_peg_builder & p_;
+    static constexpr const char * QUOTE = "<|\"|>";
+
+  public:
+    explicit common_peg_gemma4_builder(common_chat_peg_builder & p) : p_(p) {}
+
+    common_peg_parser gemma4_string() {
+        return p_.rule("gemma4-string", [&]() {
+            return p_.literal(QUOTE) + p_.until(QUOTE) + p_.literal(QUOTE);
+        });
+    }
+
+    common_peg_parser gemma4_number() {
+        return p_.rule("gemma4-number", [&]() {
+            auto digit1_9 = p_.chars("[1-9]", 1, 1);
+            auto digits = p_.chars("[0-9]");
+            auto int_part = p_.choice({p_.literal("0"), p_.sequence({digit1_9, p_.chars("[0-9]", 0, -1)})});
+            auto frac = p_.sequence({p_.literal("."), digits});
+            auto exp = p_.sequence({p_.choice({p_.literal("e"), p_.literal("E")}),
+                                    p_.optional(p_.chars("[+-]", 1, 1)), digits});
+            auto not_number_continuation = p_.negate(p_.chars("[0-9.eE+-]", 1, 1));
+            return p_.sequence({p_.optional(p_.literal("-")), int_part, p_.optional(frac),
+                                p_.optional(exp), not_number_continuation});
+        });
+    }
+
+    common_peg_parser gemma4_bool() {
+        return p_.rule("gemma4-bool", [&]() {
+            return p_.choice({p_.literal("true"), p_.literal("false")});
+        });
+    }
+
+    common_peg_parser gemma4_null() {
+        return p_.rule("gemma4-null", [&]() {
+            return p_.literal("null");
+        });
+    }
+
+    common_peg_parser gemma4_dict() {
+        return p_.rule("gemma4-dict", [&]() {
+            auto ws = p_.space();
+            auto key = p_.until(":");
+            auto member = p_.sequence({key, p_.literal(":"), ws, gemma4_value()});
+            auto members = p_.sequence({member, p_.zero_or_more(p_.sequence({p_.literal(","), ws, member}))});
+            return p_.sequence({
+                p_.literal("{"), ws,
+                p_.choice({p_.literal("}"), p_.sequence({members, ws, p_.literal("}")})})
+            });
+        });
+    }
+
+    common_peg_parser gemma4_array() {
+        return p_.rule("gemma4-array", [&]() {
+            auto ws = p_.space();
+            auto elements = p_.sequence({gemma4_value(), p_.zero_or_more(p_.sequence({p_.literal(","), ws, gemma4_value()}))});
+            return p_.sequence({
+                p_.literal("["), ws,
+                p_.choice({p_.literal("]"), p_.sequence({elements, ws, p_.literal("]")})})
+            });
+        });
+    }
+
+    common_peg_parser gemma4_value() {
+        return p_.rule("gemma4-value", [&]() {
+            return p_.choice({gemma4_string(), gemma4_dict(), gemma4_array(),
+                              gemma4_number(), gemma4_bool(), gemma4_null()});
+        });
+    }
+
+    // Select the appropriate value parser based on JSON schema type.
+    // Does NOT use schema() - the gemma4 types are pure PEG without
+    // JSON schema metadata, so GBNF is generated directly from the
+    // PEG structure.
+    common_peg_parser gemma4_value_for_type(const json & schema) {
+        if (!schema.contains("type") || !schema.at("type").is_string()) {
+            return gemma4_value();
+        }
+        std::string type = schema.at("type").get<std::string>();
+        if (type == "string") { return gemma4_string(); }
+        if (type == "number") { return gemma4_number(); }
+        if (type == "integer") { return gemma4_number(); }
+        if (type == "boolean") { return gemma4_bool(); }
+        if (type == "object") { return gemma4_dict(); }
+        if (type == "array") { return gemma4_array(); }
+        return gemma4_value();
+    }
+};
+
+} // anonymous namespace
+
 // Helper to iterate over tools/functions
 static void foreach_function(const json & tools, const std::function<void(const json &)> & fn) {
     for (const auto & tool : tools) {
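For reference, the `gemma4_number` rule above encodes the standard JSON number shape, plus a one-character negative lookahead so that a streaming parse cannot stop in the middle of a number. A minimal standalone sketch of the same shape using `std::regex` (illustrative only, not part of the patch):

```cpp
// Reference sketch of the number shape built by gemma4_number():
//   -?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-]?[0-9]+)?
// The PEG version additionally uses negate(chars("[0-9.eE+-]")) so a
// partial match such as "12e" is not accepted while tokens stream in.
#include <cassert>
#include <regex>
#include <string>

static bool is_json_number_shape(const std::string & s) {
    static const std::regex re(R"(-?(0|[1-9][0-9]*)(\.[0-9]+)?([eE][+-]?[0-9]+)?)");
    return std::regex_match(s, re);
}

int main() {
    assert(is_json_number_shape("-12.5e+3"));
    assert(!is_json_number_shape("012"));   // leading zeros are rejected
    assert(!is_json_number_shape("1.2.3")); // trailing ".3" cannot match
}
```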

@@ -43,7 +142,9 @@ common_chat_params peg_generator::generate_parser(const common_chat_template &
     // Create the result structure
     common_chat_params data;
     data.prompt = common_chat_template_direct_apply(tmpl, inputs);
-    data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
+    data.format = (autoparser.tools.format.mode == tool_format::TAG_WITH_GEMMA4_DICT)
+                      ? COMMON_CHAT_FORMAT_PEG_GEMMA4
+                      : COMMON_CHAT_FORMAT_PEG_NATIVE;
     data.preserved_tokens = autoparser.preserved_tokens;

     auto parser = autoparser.build_parser(inputs);

@@ -92,6 +193,7 @@ common_peg_arena autoparser::build_parser(const generation_params & inputs) cons

     ctx.extracting_reasoning = extract_reasoning && reasoning.mode != reasoning_mode::NONE;
+    ctx.content = &content;
     ctx.reasoning = &reasoning;

     // Build reasoning parser
     ctx.reasoning_parser = reasoning.build_parser(ctx);

@@ -216,6 +318,44 @@ common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_cont
           p.end();
 }

+common_peg_parser analyze_tools::build_func_parser(common_chat_peg_builder & p, const std::string & name,
+                                                   const common_peg_parser & call_id_section, bool have_call_id,
+                                                   const common_peg_parser & args,
+                                                   std::optional<common_peg_parser> atomic_peek) const {
+    auto open = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix);
+    bool matched_atomic = false;
+    common_peg_parser func_parser = p.eps();
+
+    if (!function.name_suffix.empty()) {
+        func_parser = open + call_id_section + p.space() + args;
+        matched_atomic = true;
+    } else if (have_call_id) {
+        func_parser = p.atomic(open + call_id_section) + p.space() + args;
+        matched_atomic = true;
+    } else if (atomic_peek.has_value()) {
+        func_parser = p.atomic(open + call_id_section + p.space() + *atomic_peek) + args;
+        matched_atomic = true;
+    } else {
+        func_parser = open + call_id_section + p.space() + args;
+    }
+
+    if (!function.close.empty()) {
+        func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
+    } else if (!format.per_call_end.empty()) {
+        // When there's no func_close but there is a per_call_end marker, use peek() to ensure
+        // we only emit tool_close when we can actually see the closing marker. This prevents
+        // premature closing during partial parsing when we've seen e.g. "</" which could be
+        // either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
+        func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
+    } else {
+        func_parser = func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
+    }
+    if (!matched_atomic) {
+        func_parser = p.atomic(func_parser);
+    }
+    return func_parser;
+}
+
 common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const {
     auto & p = ctx.p;
     const auto & inputs = ctx.inputs;

@@ -229,17 +369,27 @@ common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context
         const auto & schema = func.contains("parameters") ? func.at("parameters") : json::object();

         // Build call_id parser based on position (if supported)
+        bool have_call_id = false;
         common_peg_parser call_id_section = p.eps();
         if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
-            !call_id.suffix.empty()) {
-            call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
+            (!call_id.suffix.empty() || !arguments.start.empty())) {
+            if (!call_id.suffix.empty()) {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix;
+            } else {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
+            }
+            have_call_id = true;
         }
+        auto args_parser = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
+        if (!arguments.start.empty()) {
+            args_parser = p.literal(arguments.start) + args_parser;
+        }
+        if (!arguments.end.empty()) {
+            args_parser = args_parser + p.literal(arguments.end);
+        }

-        auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                           call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema));
-        if (!function.close.empty()) {
-            func_parser = func_parser + function.close;
-        }
+        auto atomic_peek = !arguments.start.empty() ? std::optional(p.peek(p.literal(arguments.start))) : std::nullopt;
+        auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_parser, atomic_peek);
         tool_choice |= p.rule("tool-" + name, func_parser);
     });

@@ -299,12 +449,34 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
         for (const auto & [param_name, param_schema] : properties.items()) {
             bool is_required = required.find(param_name) != required.end();
             std::string type = "object";
-            auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object();
-            if (type_obj.is_string()) {
-                type_obj.get_to(type);
-            } else if (type_obj.is_object()) {
-                if (type_obj.contains("type") && type_obj.at("type").is_string()) {
-                    type_obj.at("type").get_to(type);
+            if (param_schema.contains("type")) {
+                const auto & type_obj = param_schema.at("type");
+                if (type_obj.is_string()) {
+                    type_obj.get_to(type);
+                } else if (type_obj.is_array()) {
+                    // Handle nullable types like ["string", "null"]
+                    for (const auto & t : type_obj) {
+                        if (t.is_string() && t.get<std::string>() != "null") {
+                            type = t.get<std::string>();
+                            break;
+                        }
+                    }
+                } else if (type_obj.is_object()) {
+                    if (type_obj.contains("type") && type_obj.at("type").is_string()) {
+                        type_obj.at("type").get_to(type);
+                    }
                 }
             }
+            // Infer string type from enum values when type is unspecified
+            if (type == "object" && param_schema.contains("enum")) {
+                const auto & enum_vals = param_schema.at("enum");
+                if (enum_vals.is_array()) {
+                    for (const auto & v : enum_vals) {
+                        if (v.is_string()) {
+                            type = "string";
+                            break;
+                        }
+                    }
+                }
+            }
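The two fallbacks above can be exercised in isolation: a `"type"` array such as `["string", "null"]` denotes a nullable value, and an `"enum"` of strings implies `"string"` when `"type"` is absent. A minimal standalone sketch of the same rules, assuming only nlohmann/json is available:

```cpp
#include <cassert>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Standalone restatement of the type-extraction rules in the hunk above.
static std::string param_type(const json & schema) {
    std::string type = "object";
    if (schema.contains("type")) {
        const auto & t = schema.at("type");
        if (t.is_string()) {
            t.get_to(type);
        } else if (t.is_array()) {
            // Nullable types like ["string", "null"]: take the non-null entry.
            for (const auto & v : t) {
                if (v.is_string() && v.get<std::string>() != "null") {
                    type = v.get<std::string>();
                    break;
                }
            }
        }
    }
    // Infer "string" from enum values when "type" is unspecified.
    if (type == "object" && schema.contains("enum") && schema.at("enum").is_array()) {
        for (const auto & v : schema.at("enum")) {
            if (v.is_string()) { type = "string"; break; }
        }
    }
    return type;
}

int main() {
    assert(param_type(json::parse(R"({"type": ["string", "null"]})")) == "string");
    assert(param_type(json::parse(R"({"enum": ["a", "b"]})")) == "string");
    assert(param_type(json::parse(R"({})")) == "object");
}
```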
@@ -347,52 +519,31 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
             args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size());
         }

         if (!arguments.start.empty()) {
             args_seq = p.literal(arguments.start) + args_seq;
         }
         if (!arguments.end.empty()) {
             args_seq = args_seq + p.literal(arguments.end);
         }

         // Build call_id parser based on position (if supported)
         common_peg_parser call_id_section = p.eps();
         bool have_call_id = false;
         if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() &&
-            !call_id.suffix.empty()) {
+            (!call_id.suffix.empty() || !arguments.start.empty())) {
             have_call_id = true;
-            call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
-        }
-
-        bool matched_atomic = false;
-        common_peg_parser func_parser = p.eps();
-        if (!function.name_suffix.empty()) {
-            func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                          call_id_section + p.space() + args_seq;
-            matched_atomic = true;
-        } else if (have_call_id) {
-            func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                          call_id_section) + p.space() + args_seq;
-            matched_atomic = true;
-        } else if (!arguments.name_prefix.empty() && !required_parsers.empty()) {
-            // Only peek for an arg tag when there are required args that must follow.
-            // When all args are optional, the model may emit no arg tags at all (#20650).
-            func_parser = p.atomic(p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                          call_id_section + p.space() + p.peek(p.literal(arguments.name_prefix))) + args_seq;
-            matched_atomic = true;
-        } else {
-            func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) +
-                          call_id_section + p.space() + args_seq;
-        }
-
-        if (!function.close.empty()) {
-            func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close));
-        } else if (!format.per_call_end.empty()) {
-            // When there's no func_close but there is a per_call_end marker, use peek() to ensure
-            // we only emit tool_close when we can actually see the closing marker. This prevents
-            // premature closing during partial parsing when we've seen e.g. "</" which could be
-            // either "</tool_call>" (end) or "<arg_key>" prefix that failed to match.
-            func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end)));
-        } else {
-            func_parser =
-                func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper
-        }
-        if (!matched_atomic) {
-            func_parser = p.atomic(func_parser);
+            if (!call_id.suffix.empty()) {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix)) + call_id.suffix);
+            } else {
+                call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(arguments.start)));
+            }
         }
+
+        // Only peek for an arg tag when there are required args that must follow.
+        // When all args are optional, the model may emit no arg tags at all (#20650).
+        auto atomic_peek = (!arguments.name_prefix.empty() && !required_parsers.empty()) ?
+            std::optional(p.peek(p.literal(arguments.name_prefix))) : std::nullopt;
+        auto func_parser = build_func_parser(p, name, call_id_section, have_call_id, args_seq, atomic_peek);
         tool_choice |= p.rule("tool-" + name, func_parser);
     });

@@ -440,7 +591,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
     const auto & inputs = ctx.inputs;
     bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED;

-    // The Gemma4 string quote token used in place of JSON "
+    common_peg_gemma4_builder g4(p);
     static const std::string QUOTE = "<|\"|>";

     common_peg_parser tool_choice = p.choice();

@@ -451,7 +602,6 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
         const auto & params = func.at("parameters");

         if (!params.contains("properties") || !params.at("properties").is_object()) {
-            // No arguments - just match the function name with empty braces
             auto func_parser = p.atomic(
                 p.tool_open(p.literal(function.name_prefix) + p.tool_name(p.literal(name)) + p.literal("{")) +
                 p.tool_args(p.eps()) +

@@ -474,9 +624,33 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
         std::vector<arg_entry> arg_entries;

         for (const auto & [param_name, param_schema] : properties.items()) {
-            std::string type = "object";
-            auto type_v = param_schema.contains("type") ? param_schema.at("type") : json::object();
-            if (type_v.is_string()) type_v.get_to(type);
+            std::string type = "object";
+            if (param_schema.contains("type")) {
+                const auto & type_v = param_schema.at("type");
+                if (type_v.is_string()) {
+                    type_v.get_to(type);
+                } else if (type_v.is_array()) {
+                    // Handle nullable types like ["string", "null"]
+                    for (const auto & t : type_v) {
+                        if (t.is_string() && t.get<std::string>() != "null") {
+                            type = t.get<std::string>();
+                            break;
+                        }
+                    }
+                }
+            }
+            // Infer string type from enum values when type is unspecified
+            if (type == "object" && param_schema.contains("enum")) {
+                const auto & enum_vals = param_schema.at("enum");
+                if (enum_vals.is_array()) {
+                    for (const auto & v : enum_vals) {
+                        if (v.is_string()) {
+                            type = "string";
+                            break;
+                        }
+                    }
+                }
+            }

             common_peg_parser value_parser = p.eps();
             if (type == "string") {

@@ -486,9 +660,18 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
                     p.tool_arg_string_value(p.schema(p.until(QUOTE),
                         "tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) +
                     p.literal(QUOTE);
+            } else if (type == "number" || type == "integer") {
+                value_parser = p.tool_arg_value(g4.gemma4_number());
+            } else if (type == "boolean") {
+                value_parser = p.tool_arg_value(g4.gemma4_bool());
+            } else if (type == "null") {
+                value_parser = p.tool_arg_value(g4.gemma4_null());
+            } else if (type == "object") {
+                value_parser = p.tool_arg_value(g4.gemma4_dict());
+            } else if (type == "array") {
+                value_parser = p.tool_arg_value(g4.gemma4_array());
             } else {
-                // Numbers, booleans: raw text up to the next comma or closing brace
-                value_parser = p.tool_arg_value(p.until_one_of({",", "}"}));
+                value_parser = p.tool_arg_value(g4.gemma4_value());
             }

             auto arg = p.tool_arg(

@@ -538,9 +721,9 @@ common_peg_parser analyze_tools::build_tool_parser_tag_gemma4_dict(parser_build_
         tool_calls = p.optional(tool_calls);
     }

-    auto content_before_tools = p.until(format.per_call_start);
+    auto content_before_tools = p.until_one_of({ format.per_call_start, ctx.reasoning->start });
     return ctx.reasoning_parser +
-           (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) +
+           (force_tools ? p.eps() : p.optional(p.content(content_before_tools) + p.optional(ctx.reasoning_parser))) +
            tool_calls + p.end();
 }
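The switch from `until(format.per_call_start)` to `until_one_of({...})` means free-form content now stops at whichever marker appears first, so reasoning emitted after content is no longer swallowed into the content span. A minimal sketch of that earliest-marker scan (the tag names below are hypothetical placeholders, not markers from any specific template):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Sketch of what an until_one_of-style combinator must do: keep text up
// to the earliest occurrence of any stop marker.
static std::string take_until_one_of(const std::string & s, const std::vector<std::string> & stops) {
    size_t cut = s.size();
    for (const auto & stop : stops) {
        if (stop.empty()) continue; // ignore unset markers
        size_t pos = s.find(stop);
        if (pos != std::string::npos && pos < cut) cut = pos;
    }
    return s.substr(0, cut);
}

int main() {
    std::string out = take_until_one_of("Hello <think>plan</think><tool_call>...",
                                        {"<tool_call>", "<think>"});
    assert(out == "Hello "); // stops at the reasoning tag, not the later tool tag
}
```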
@@ -1,7 +1,8 @@
 #pragma once

 #include "chat-auto-parser.h"
+#include "peg-parser.h"

 #include <functional>
 #include <optional>
 #include <string>

@@ -4,6 +4,7 @@
 #include "common.h"
 #include "jinja/caps.h"
+#include "peg-parser.h"
 #include "nlohmann/json.hpp"

 #include <chrono>
 #include <optional>

@@ -215,12 +216,14 @@ struct tool_id_analysis {
 // ============================================================================

+struct analyze_content;
 struct analyze_reasoning;

 struct parser_build_context {
     common_chat_peg_builder & p;
     const generation_params & inputs;
     common_peg_parser reasoning_parser;
     bool extracting_reasoning = false;
     const analyze_reasoning * reasoning = nullptr;
+    const analyze_content * content = nullptr;

     parser_build_context(common_chat_peg_builder & p, const generation_params & inputs);

@@ -353,6 +356,13 @@ struct analyze_tools : analyze_base {
     common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const;
     common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const;
+
+    // Shared helper: builds func_parser from open+call_id+args, handling atomic wrapping and close.
+    // atomic_peek: if present, used as the peek expression in the third atomicity branch.
+    common_peg_parser build_func_parser(common_chat_peg_builder & p, const std::string & name,
+                                        const common_peg_parser & call_id_section, bool have_call_id,
+                                        const common_peg_parser & args,
+                                        std::optional<common_peg_parser> atomic_peek) const;
     common_peg_parser build_tool_parser_tag_gemma4_dict(parser_build_context & ctx) const;
 };

@@ -25,6 +25,9 @@ static const std::string ARG_SECOND = "BB_ARG_SND_BB";
 static const std::string USER_MSG = "U_USER_MSG Hello END_U";
 static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
 static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
+static const std::string CALL_ID_001 = "call00001";
+static const std::string CALL_ID_002 = "call00002";
+static const std::string CALL_ID_999 = "call99999";

 static std::vector<std::function<void(const common_chat_template & tmpl, autoparser &)>> workarounds(
     { // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to

@@ -104,10 +107,11 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
           analysis.tools.function.name_suffix = "";
           analysis.tools.arguments.start = "{";
           analysis.tools.arguments.end = "}";
+          analysis.tools.arguments.name_prefix = "";
           analysis.tools.arguments.name_suffix = ":";
           analysis.tools.arguments.separator = ",";
           analysis.reasoning.mode = reasoning_mode::TAG_BASED;
-          analysis.reasoning.start = "<|channel>thought\n";
+          analysis.reasoning.start = "<|channel>thought";
           analysis.reasoning.end = "<channel|>";
           analysis.preserved_tokens.clear();
           analysis.preserved_tokens.push_back("<|tool_call>");

@@ -130,6 +134,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
           analysis.tools.function.name_prefix = "<|tool▁sep|>";
           analysis.tools.format.per_call_end = "<|tool▁call▁end|>";
+          analysis.tools.function.close = "```";
           LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
       }
   }
 });

@@ -157,7 +162,7 @@ static json user_msg = json{
     { "content", USER_MSG }
 };

-static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") {
+static json build_tool_call(const std::string & name, const json & args, const std::string & id = CALL_ID_001) {
     return json{
         { "id", id },
         { "type", "function" },

@@ -165,17 +170,17 @@ static json build_tool_call(const std::string & name, const json & args, const s
     };
 }

-static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001");
-static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001");
-static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001");
-static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001");
+static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), CALL_ID_001);
+static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, CALL_ID_001);
+static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, CALL_ID_001);
+static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, CALL_ID_001);

 static json first_tool_call =
-    build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001");
+    build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_001);
 static json second_tool_call =
-    build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002");
+    build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_002);
 static json first_tool_call_alt_id =
-    build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999");
+    build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, CALL_ID_999);

 template <typename T>
 static std::string mode_to_str(T mode) {

@@ -214,6 +219,11 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
     LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str());
     LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str());
+    LOG_DBG("func_close: '%s'\n", tools.function.close.c_str());
+    LOG_DBG("call_id_prefix: '%s'\n", tools.call_id.prefix.c_str());
+    LOG_DBG("call_id_suffix: '%s'\n", tools.call_id.suffix.c_str());
+    LOG_DBG("call_id_pos: '%s'\n", mode_to_str(tools.call_id.pos).c_str());
     LOG_DBG("args_start: '%s'\n", tools.arguments.start.c_str());
     LOG_DBG("args_end: '%s'\n", tools.arguments.end.c_str());
     LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str());
     LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str());
+    LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str());

@@ -582,12 +592,15 @@ analyze_tools::analyze_tools(const common_chat_template & tmpl,
         if (caps.supports_parallel_tool_calls) {
             check_per_call_markers();
         }
+        LOG_DBG(ANSI_ORANGE "Phase 3a: Function call analysis\n" ANSI_RESET);
         extract_function_markers();
+        LOG_DBG(ANSI_ORANGE "Phase 3b: Argument analysis\n" ANSI_RESET);
         if (format.mode == tool_format::TAG_WITH_TAGGED) {
             analyze_arguments();
         }
         extract_argument_separator();
         extract_args_markers();
+        LOG_DBG(ANSI_ORANGE "Phase 3c: Call id analysis\n" ANSI_RESET);
         extract_call_id_markers();
     }
 }

@@ -978,8 +991,6 @@ void analyze_tools::extract_function_markers() {
 }

 void analyze_tools::analyze_arguments() {
-    LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET);
-
     extract_argument_name_markers();
     extract_argument_value_markers();
 }

@@ -1188,7 +1199,7 @@ void analyze_tools::extract_args_markers() {

     const auto & diff = comparison->diff;

-    if (format.mode != tool_format::JSON_NATIVE) {
+    if (format.mode == tool_format::JSON_NATIVE) {
         std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start;
         std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end;
         // these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones

@@ -1210,6 +1221,10 @@ void analyze_tools::extract_args_markers() {
         if (find_fun != std::string::npos) {
             args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size());
         }
+        size_t find_call_id = args_start.find(CALL_ID_001);
+        if (find_call_id != std::string::npos) {
+            args_start = args_start.substr(find_call_id + CALL_ID_001.size(), args_start.size() - find_call_id - CALL_ID_001.size());
+        }
         arguments.start = args_start;
         arguments.end = args_end;
     }
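The new CALL_ID_001 handling mirrors the existing FUN_FIRST case: if the probe value leaked into the captured `args_start`, everything up to and including it is discarded so sample values never end up inside a detected marker. A standalone sketch of that stripping step:

```cpp
#include <cassert>
#include <string>

// Sketch of the stripping step above: once a known probe value (function
// name or call id) is found in the captured prefix, keep only what
// follows it.
static std::string strip_through(const std::string & s, const std::string & probe) {
    size_t pos = s.find(probe);
    return pos == std::string::npos ? s : s.substr(pos + probe.size());
}

int main() {
    // args_start as captured from a rendered template, with the probe
    // call id "call00001" still embedded in it (hypothetical marker text):
    std::string args_start = "<call>call00001{";
    assert(strip_through(args_start, "call00001") == "{");
}
```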
@@ -1249,8 +1264,8 @@ void analyze_tools::extract_call_id_markers() {
         return;
     }

-    std::string id_value_1 = "call00001";
-    std::string id_value_2 = "call99999";
+    std::string id_value_1 = CALL_ID_001;
+    std::string id_value_2 = CALL_ID_999;

     size_t common_id_prefix_len = 0;
     for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) {

@@ -1349,6 +1364,14 @@ void analyze_tools::extract_call_id_markers() {
         call_id.suffix = find_first_marker(before_func);
     }

+    if (call_id.prefix == arguments.end) {
+        call_id.prefix = "";
+    }
+
+    if (call_id.suffix == arguments.start) {
+        call_id.suffix = "";
+    }
+
     // When call_id is detected, per_call_end may have been incorrectly set to include
     // the call_id_suffix and sample args. Clear it if it starts with call_id_suffix.
     if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() &&

@@ -75,6 +75,84 @@ static std::string escape_json_string_inner(const std::string & s) {
     return escaped;
 }

+static const std::string GEMMA4_QUOTE = "<|\"|>";
+
+static std::string normalize_gemma4_to_json(const std::string & input) {
+    std::string result;
+    result.reserve(input.size() * 2);
+
+    enum Ctx { DICT, ARRAY };
+    std::vector<Ctx> ctx;
+
+    auto is_ws = [](char c) { return c == ' ' || c == '\t' || c == '\n' || c == '\r'; };
+    auto skip_ws = [&](size_t & pos) {
+        while (pos < input.size() && is_ws(input[pos])) {
+            result += input[pos++];
+        }
+    };
+
+    auto quote_unquoted_key = [&](size_t & pos) {
+        if (pos < input.size() && input[pos] != '"' && input[pos] != '}') {
+            result += '"';
+            while (pos < input.size() && input[pos] != ':' && !is_ws(input[pos])) {
+                result += input[pos++];
+            }
+            result += '"';
+            skip_ws(pos);
+        }
+    };
+
+    size_t i = 0;
+    while (i < input.size()) {
+        if (i + GEMMA4_QUOTE.size() <= input.size() &&
+            input.compare(i, GEMMA4_QUOTE.size(), GEMMA4_QUOTE) == 0) {
+            result += '"';
+            i += GEMMA4_QUOTE.size();
+            continue;
+        }
+
+        char c = input[i];
+
+        if (c == '{') {
+            result += c;
+            ctx.push_back(DICT);
+            ++i;
+            skip_ws(i);
+            quote_unquoted_key(i);
+            continue;
+        }
+        if (c == '}') {
+            result += c;
+            if (!ctx.empty()) ctx.pop_back();
+            ++i;
+            continue;
+        }
+        if (c == '[') {
+            result += c;
+            ctx.push_back(ARRAY);
+            ++i;
+            continue;
+        }
+        if (c == ']') {
+            result += c;
+            if (!ctx.empty()) ctx.pop_back();
+            ++i;
+            continue;
+        }
+        if (c == ',' && !ctx.empty() && ctx.back() == DICT) {
+            result += c;
+            ++i;
+            skip_ws(i);
+            quote_unquoted_key(i);
+            continue;
+        }
+
+        result += c;
+        ++i;
+    }
+    return result;
+}
+
 // Convert Python-style single-quoted strings to JSON double-quoted strings
 // Only converts outer string delimiters, properly handling escape sequences:
 // - {'key': 'value'} -> {"key": "value"}
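`normalize_gemma4_to_json` rewrites the Gemma4 surface syntax into plain JSON in a single pass: the `<|"|>` token becomes a double quote, and bare dictionary keys are wrapped in quotes (the DICT/ARRAY context stack ensures only dict keys are touched). A reduced sketch of just the quote-token substitution, with the end-to-end contract noted in comments:

```cpp
#include <cassert>
#include <string>

// Reduced sketch of the first step of normalize_gemma4_to_json above:
// replace every <|"|> token with a JSON double quote. Key quoting is
// handled separately by the quote_unquoted_key logic in the function.
static std::string replace_quote_token(std::string s) {
    static const std::string QUOTE = "<|\"|>";
    for (size_t pos = s.find(QUOTE); pos != std::string::npos; pos = s.find(QUOTE, pos + 1)) {
        s.replace(pos, QUOTE.size(), "\"");
    }
    return s;
}

int main() {
    // {city:<|"|>Paris<|"|>} -> {city:"Paris"} -> (after key quoting) {"city":"Paris"}
    assert(replace_quote_token("{city:<|\"|>Paris<|\"|>}") == "{city:\"Paris\"}");
}
```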
@@ -214,6 +292,14 @@ std::string & common_chat_peg_mapper::args_target() {
     return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer;
 }

+std::string common_chat_peg_mapper::normalize_container_value(const std::string & input) {
+    return normalize_quotes_to_json(input);
+}
+
+std::string common_chat_peg_gemma4_mapper::normalize_container_value(const std::string & input) {
+    return normalize_quotes_to_json(normalize_gemma4_to_json(input));
+}
+
 void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena,
                                       const common_peg_parse_result & parse_result_arg) {
     arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); });
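`normalize_container_value` is a protected virtual hook (see the header change further below): the base mapper applies only the Python-quote normalization, while the Gemma4 mapper runs the Gemma4 pass first and then delegates. A miniature of that dispatch shape, using stand-in classes rather than the real ones defined in this patch:

```cpp
#include <cassert>
#include <string>

// Stand-in miniature of the hook introduced above; only the dispatch
// shape is shown, not the real normalization logic.
struct mapper_sketch {
    virtual ~mapper_sketch() = default;
    virtual std::string normalize(const std::string & in) const {
        return "quotes(" + in + ")"; // base: quote normalization only
    }
};

struct gemma4_mapper_sketch : mapper_sketch {
    std::string normalize(const std::string & in) const override {
        // gemma4 pass first, then the shared quote pass
        return mapper_sketch::normalize("gemma4(" + in + ")");
    }
};

int main() {
    gemma4_mapper_sketch g;
    const mapper_sketch & m = g;
    assert(m.normalize("x") == "quotes(gemma4(x))");
}
```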
|
||||
|
|
@ -352,7 +438,7 @@ void common_chat_peg_mapper::map(const common_peg_ast_node & node) {
|
|||
// For potential containers, normalize Python-style single quotes to JSON double quotes
|
||||
bool is_potential_container = value_content[0] == '[' || value_content[0] == '{';
|
||||
if (is_potential_container) {
|
||||
value_content = normalize_quotes_to_json(value_content);
|
||||
value_content = normalize_container_value(value_content);
|
||||
}
|
||||
|
||||
// Try to parse as JSON value (number, bool, null, object, array)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,9 @@ class common_chat_peg_mapper {
|
|||
|
||||
virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result);
|
||||
virtual void map(const common_peg_ast_node & node);
|
||||
private:
|
||||
protected:
|
||||
virtual std::string normalize_container_value(const std::string & input);
|
||||
private:
|
||||
// Tool call handling state
|
||||
std::optional<common_chat_tool_call> pending_tool_call; // Tool call waiting for name
|
||||
common_chat_tool_call * current_tool = nullptr;
|
||||
|
|
@ -30,6 +32,13 @@ class common_chat_peg_mapper {
|
|||
std::string & args_target();
|
||||
};
|
||||
|
||||
class common_chat_peg_gemma4_mapper : public common_chat_peg_mapper {
|
||||
public:
|
||||
common_chat_peg_gemma4_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {}
|
||||
protected:
|
||||
std::string normalize_container_value(const std::string & input) override;
|
||||
};
|
||||
|
||||
struct content_structure;
|
||||
struct tool_call_structure;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@
|
|||
#include "jinja/caps.h"
|
||||
#include "peg-parser.h"
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <ctime>
|
||||
|
|
@ -694,6 +696,8 @@ const char * common_chat_format_name(common_chat_format format) {
|
|||
return "peg-simple";
|
||||
case COMMON_CHAT_FORMAT_PEG_NATIVE:
|
||||
return "peg-native";
|
||||
case COMMON_CHAT_FORMAT_PEG_GEMMA4:
|
||||
return "peg-gemma4";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
}
|
||||
|
|
@ -760,12 +764,12 @@ static void foreach_parameter(const json &
|
|||
}
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
static std::string common_chat_template_direct_apply_impl(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override,
|
||||
const std::optional<json> & tools_override,
|
||||
const std::optional<json> & additional_context) {
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt) {
|
||||
jinja::context ctx(tmpl.source());
|
||||
|
||||
nlohmann::ordered_json inp = nlohmann::ordered_json{
|
||||
|
|
@ -812,6 +816,12 @@ std::string common_chat_template_direct_apply(
|
|||
return result;
|
||||
}
|
||||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
return common_chat_template_direct_apply_impl(tmpl, inputs, std::nullopt, std::nullopt, std::nullopt);
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
|
@ -862,7 +872,7 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_
|
|||
data.supports_thinking = true;
|
||||
data.thinking_start_tag = "[THINK]";
|
||||
data.thinking_end_tag = "[/THINK]";
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override = */ adjusted_messages);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
"[THINK]",
|
||||
|
|
@ -945,7 +955,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
|||
adjusted_messages.push_back(msg);
|
||||
}
|
||||
|
||||
auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
auto prompt = common_chat_template_direct_apply_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
|
||||
// Check if we need to replace the return token with end token during
|
||||
// inference and without generation prompt. For more details see:
|
||||
|
|
@ -1072,7 +1082,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.preserved_tokens = {
|
||||
">>>all",
|
||||
|
|
@ -1166,7 +1176,7 @@ static common_chat_params common_chat_params_init_kimi_k2(const common_chat_temp
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1289,7 +1299,7 @@ static common_chat_params common_chat_params_init_lfm2(const common_chat_templat
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1368,7 +1378,7 @@ static common_chat_params common_chat_params_init_lfm2_5(const common_chat_templ
|
|||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1439,7 +1449,7 @@ static common_chat_params common_chat_params_init_gigachat_v3(
|
|||
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, inputs);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = false;
|
||||
data.preserved_tokens = {
|
||||
|
|
@ -1621,7 +1631,7 @@ static json common_chat_extra_context() {
|
|||
return ctx;
|
||||
}
|
||||
|
||||
static std::optional<common_chat_params> try_specialized_template(
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params) {
|
||||
|
|
@ -1722,9 +1732,9 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
}
|
||||
|
||||
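// Render the template twice, with and without the generation prompt, and
// diff the two renders: the right-hand remainder is the generation prompt
// suffix the template appends for the assistant turn.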
params.add_generation_prompt = false;
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string no_gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
params.add_generation_prompt = true;
|
||||
std::string gen_prompt = common_chat_template_direct_apply(tmpl, params);
|
||||
std::string gen_prompt = common_chat_template_direct_apply_impl(tmpl, params);
|
||||
auto diff = calculate_diff_split(no_gen_prompt, gen_prompt);
|
||||
params.generation_prompt = diff.right;
|
||||
|
||||
|
|
@ -1758,7 +1768,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
common_chat_params data;
|
||||
auto params_copy = params;
|
||||
params_copy.reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
||||
data.prompt = common_chat_template_direct_apply(tmpl, params_copy);
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, params_copy);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.generation_prompt = params.generation_prompt;
|
||||
auto parser = build_chat_peg_parser([&params](common_chat_peg_builder &p) {
|
||||
|
|
@ -1768,7 +1778,7 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
|||
return data;
|
||||
}
|
||||
|
||||
if (auto result = try_specialized_template(tmpl, src, params)) {
|
||||
if (auto result = common_chat_try_specialized_template(tmpl, src, params)) {
|
||||
result->generation_prompt = params.generation_prompt;
|
||||
return *result;
|
||||
}
|
||||
|
|
@ -1905,8 +1915,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
// Try to extract any partial results from what was successfully parsed
|
||||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
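// Gemma 4 output needs its own AST-to-message mapping, so pick the mapper
// based on the detected chat format.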
std::unique_ptr<common_chat_peg_mapper> mapper;
|
||||
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
|
||||
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
|
||||
} else {
|
||||
mapper = std::make_unique<common_chat_peg_mapper>(msg);
|
||||
}
|
||||
mapper->from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str());
|
||||
|
|
@ -1921,8 +1936,13 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
|||
common_chat_msg msg;
|
||||
msg.role = "assistant";
|
||||
|
||||
auto mapper = common_chat_peg_mapper(msg);
|
||||
mapper.from_ast(ctx.ast, result);
|
||||
std::unique_ptr<common_chat_peg_mapper> mapper;
|
||||
if (params.format == COMMON_CHAT_FORMAT_PEG_GEMMA4) {
|
||||
mapper = std::make_unique<common_chat_peg_gemma4_mapper>(msg);
|
||||
} else {
|
||||
mapper = std::make_unique<common_chat_peg_mapper>(msg);
|
||||
}
|
||||
mapper->from_ast(ctx.ast, result);
|
||||
|
||||
if (ctx.is_debug()) {
|
||||
fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str());
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
#include "peg-parser.h"
|
||||
#include "jinja/parser.h"
|
||||
#include "jinja/runtime.h"
|
||||
#include "jinja/caps.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
|
||||
#include "nlohmann/json_fwd.hpp"
|
||||
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
|
|
@ -19,8 +19,6 @@
|
|||
using chat_template_caps = jinja::caps;
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
struct common_chat_templates;
|
||||
|
||||
namespace autoparser {
|
||||
|
|
@ -75,41 +73,9 @@ struct common_chat_template {
|
|||
const std::string & bos_token() const { return bos_tok; }
|
||||
const std::string & eos_token() const { return eos_tok; }
|
||||
|
||||
// TODO: this is ugly, refactor it somehow
|
||||
json add_system(const json & messages, const std::string & system_prompt) const {
|
||||
GGML_ASSERT(messages.is_array());
|
||||
auto msgs_copy = messages;
|
||||
if (!caps.supports_system_role) {
|
||||
if (msgs_copy.empty()) {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "user"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else {
|
||||
auto & first_msg = msgs_copy[0];
|
||||
if (!first_msg.contains("content")) {
|
||||
first_msg["content"] = "";
|
||||
}
|
||||
first_msg["content"] = system_prompt + "\n\n"
|
||||
+ first_msg["content"].get<std::string>();
|
||||
}
|
||||
} else {
|
||||
if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") {
|
||||
msgs_copy.insert(msgs_copy.begin(), json{
|
||||
{"role", "system"},
|
||||
{"content", system_prompt}
|
||||
});
|
||||
} else if (msgs_copy[0].at("role") == "system") {
|
||||
msgs_copy[0]["content"] = system_prompt;
|
||||
}
|
||||
}
|
||||
return msgs_copy;
|
||||
}
|
||||
|
||||
chat_template_caps original_caps() const {
|
||||
return caps;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct common_chat_msg {
|
||||
|
|
@ -184,6 +150,7 @@ enum common_chat_format {
|
|||
// These are intended to be parsed by the PEG parser
|
||||
COMMON_CHAT_FORMAT_PEG_SIMPLE,
|
||||
COMMON_CHAT_FORMAT_PEG_NATIVE,
|
||||
COMMON_CHAT_FORMAT_PEG_GEMMA4,
|
||||
|
||||
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
|
||||
};
|
||||
|
|
@ -256,8 +223,8 @@ common_chat_templates_ptr common_chat_templates_init(const struct llama_model *
|
|||
const std::string & bos_token_override = "",
|
||||
const std::string & eos_token_override = "");
|
||||
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
|
||||
std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = "");
|
||||
|
||||
struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls,
|
||||
const struct common_chat_templates_inputs & inputs);
|
||||
|
|
@ -274,9 +241,9 @@ std::string common_chat_format_example(const struct common_chat_templates *
|
|||
bool use_jinja,
|
||||
const std::map<std::string, std::string> & chat_template_kwargs);
|
||||
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
const char * common_chat_format_name(common_chat_format format);
|
||||
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params);
|
||||
|
||||
// used by arg and server
|
||||
const char * common_reasoning_format_name(common_reasoning_format format);
|
||||
|
|
@ -302,7 +269,9 @@ std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_tem
|
|||
|
||||
std::string common_chat_template_direct_apply(
|
||||
const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs,
|
||||
const std::optional<json> & messages_override = std::nullopt,
|
||||
const std::optional<json> & tools_override = std::nullopt,
|
||||
const std::optional<json> & additional_context = std::nullopt);
|
||||
const autoparser::generation_params & inputs);
|
||||
|
||||
std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
const common_chat_template & tmpl,
|
||||
const std::string & src,
|
||||
const autoparser::generation_params & params);
|
||||
|
|
|
|||
|
|
@ -306,6 +306,19 @@ value filter_expression::execute_impl(context & ctx) {
|
|||
filter_id = "strip"; // alias
|
||||
}
|
||||
JJ_DEBUG("Applying filter '%s' to %s", filter_id.c_str(), input->type().c_str());
|
||||
// TODO: Refactor filters so this coercion can be done automatically
|
||||
if (!input->is_undefined() && !is_val<value_string>(input) && (
|
||||
filter_id == "capitalize" ||
|
||||
filter_id == "lower" ||
|
||||
filter_id == "replace" ||
|
||||
filter_id == "strip" ||
|
||||
filter_id == "title" ||
|
||||
filter_id == "upper" ||
|
||||
filter_id == "wordcount"
|
||||
)) {
|
||||
JJ_DEBUG("Coercing %s to String for '%s' filter", input->type().c_str(), filter_id.c_str());
|
||||
input = mk_val<value_string>(input->as_string());
|
||||
}
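// e.g. "{{ 42 | upper }}" now behaves like "{{ '42' | upper }}" instead of
// depending on the integer type exposing a string builtin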
|
||||
return try_builtin_func(ctx, filter_id, input)->invoke(func_args(ctx));
|
||||
|
||||
} else if (is_stmt<call_expression>(filter)) {
|
||||
|
|
|
|||
|
|
@ -465,8 +465,9 @@ const func_builtins & value_int_t::get_builtins() const {
|
|||
double val = static_cast<double>(args.get_pos(0)->as_int());
|
||||
return mk_val<value_float>(val);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -485,8 +486,9 @@ const func_builtins & value_float_t::get_builtins() const {
|
|||
int64_t val = static_cast<int64_t>(args.get_pos(0)->as_float());
|
||||
return mk_val<value_int>(val);
|
||||
}},
|
||||
{"tojson", tojson},
|
||||
{"safe", tojson},
|
||||
{"string", tojson},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
}
|
||||
|
|
@ -771,6 +773,11 @@ const func_builtins & value_string_t::get_builtins() const {
|
|||
|
||||
|
||||
const func_builtins & value_bool_t::get_builtins() const {
|
||||
static const func_handler tostring = [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_bool>();
|
||||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_string>(val ? "True" : "False");
|
||||
};
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"int", [](const func_args & args) -> value {
|
||||
|
|
@ -783,11 +790,8 @@ const func_builtins & value_bool_t::get_builtins() const {
|
|||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_float>(val ? 1.0 : 0.0);
|
||||
}},
|
||||
{"string", [](const func_args & args) -> value {
|
||||
args.ensure_vals<value_bool>();
|
||||
bool val = args.get_pos(0)->as_bool();
|
||||
return mk_val<value_string>(val ? "True" : "False");
|
||||
}},
|
||||
{"safe", tostring},
|
||||
{"string", tostring},
|
||||
{"tojson", tojson},
|
||||
};
|
||||
return builtins;
|
||||
|
|
@ -1100,18 +1104,14 @@ const func_builtins & value_object_t::get_builtins() const {
|
|||
}
|
||||
|
||||
const func_builtins & value_none_t::get_builtins() const {
|
||||
static const func_handler tostring = [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
};
|
||||
static const func_builtins builtins = {
|
||||
{"default", default_value},
|
||||
{"tojson", tojson},
|
||||
{"string", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"safe", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"strip", [](const func_args &) -> value {
|
||||
return mk_val<value_string>("None");
|
||||
}},
|
||||
{"string", tostring},
|
||||
{"safe", tostring},
|
||||
{"items", empty_value_fn<value_array>},
|
||||
{"map", empty_value_fn<value_array>},
|
||||
{"reject", empty_value_fn<value_array>},
|
||||
|
|
|
|||
|
|
@ -1561,7 +1561,23 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
|||
if (!s.schema) {
|
||||
return true;
|
||||
}
|
||||
if (s.raw && s.schema->contains("type") && s.schema->at("type").is_string() && s.schema->at("type") == "string") {
|
||||
if (s.raw && s.schema->contains("type")) {
|
||||
const auto & type_val = s.schema->at("type");
|
||||
if (type_val.is_string() && type_val == "string") {
|
||||
return true;
|
||||
}
|
||||
// Handle nullable types like ["string", "null"] - delegate when the
|
||||
// non-null type is string, since the tagged format uses raw text
|
||||
if (type_val.is_array()) {
|
||||
for (const auto & t : type_val) {
|
||||
if (t.is_string() && t.get<std::string>() != "null") {
|
||||
return t.get<std::string>() == "string";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Delegate for enum schemas in raw mode - enum values are literal strings
|
||||
if (s.raw && !s.schema->contains("type") && s.schema->contains("enum")) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -7464,9 +7464,6 @@ class Gemma4Model(Gemma3Model):
|
|||
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
|
||||
# but I don't have time to dive into them right now;
|
||||
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
|
||||
self.gguf_writer.add_tokenizer_model("gemma4")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
|
|
|
|||
|
|
@ -57,13 +57,14 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based
|
|||
|
||||
## Supported Operations
|
||||
|
||||
The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** operations only. Other operations are handled by the standard CPU backend.
|
||||
The ZenDNN backend accelerates **matrix multiplication (MUL_MAT)** and **expert-based matrix multiplication (MUL_MAT_ID)** operations. Other operations are handled by the standard CPU backend.
|
||||
|
||||
| Operation | Status | Notes |
|
||||
|:-------------|:-------:|:----------------------------------------------:|
|
||||
| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul |
|
||||
| MUL_MAT_ID | Support | Accelerated via ZenDNN LowOHA MatMul (MoE) |
|
||||
|
||||
*Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs).
|
||||
*Note:* Since MUL_MAT and MUL_MAT_ID are accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs and Mixture-of-Experts models).
|
||||
|
||||
## DataType Supports
|
||||
|
||||
|
|
@ -181,7 +182,7 @@ For detailed profiling and logging options, refer to the [ZenDNN Logging Documen
|
|||
|
||||
## Known Issues
|
||||
|
||||
- **Limited operation support**: Currently only matrix multiplication (MUL_MAT) is accelerated via ZenDNN. Other operations fall back to the standard CPU backend.
|
||||
- **Limited operation support**: Currently matrix multiplication (MUL_MAT) and expert-based matrix multiplication (MUL_MAT_ID) are accelerated via ZenDNN. Other operations fall back to the standard CPU backend. Future updates may expand supported operations.
|
||||
- **BF16 support**: BF16 operations require AMD Zen 4 or Zen 5 architecture (EPYC 9004/9005 series). On older CPUs, operations will use FP32.
|
||||
- **NUMA awareness**: For multi-socket systems, manual NUMA binding may be required for optimal performance.
|
||||
|
||||
|
|
@ -216,4 +217,4 @@ Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-t
|
|||
|
||||
## TODO
|
||||
|
||||
- Expand operation support beyond MUL_MAT (attention operations, activations, etc.)
|
||||
- Expand operation support beyond MUL_MAT and MUL_MAT_ID (attention operations, activations, etc.)
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ You can download it from your Linux distro's package manager or from here: [ROCm
|
|||
|
||||
|
||||
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
|
||||
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
|
||||
If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. Note that [`HSA_OVERRIDE_GFX_VERSION`] is [not supported on Windows](https://github.com/ROCm/ROCm/issues/2654).
|
||||
|
||||
### Unified Memory
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ Legend:
|
|||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
|
||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ❌ | 🟡 | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | ✅ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
9986
docs/ops/ZenDNN.csv
File diff suppressed because it is too large
|
|
@ -0,0 +1,730 @@
|
|||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
|
||||
constexpr unsigned int SWIZZLE_BITS_1 = 4;
|
||||
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
|
||||
constexpr unsigned int SWIZZLE_BITS_2 = 2;
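// Illustrative worked example of the two-step XOR swizzle on a vectorized
// shared-memory index (value chosen arbitrarily): for dst_index = 0b11010 (26),
//   step 1: 26 ^ ((26 & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1) = 26 ^ 1 = 27
//   step 2: 27 ^ ((27 & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2) = 27 ^ 2 = 25
// spreading logically adjacent float4 slots across shared-memory banks.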
|
||||
|
||||
typedef struct{
|
||||
unsigned int n; //batch size
|
||||
unsigned int c; //number of channels
|
||||
unsigned int h; //height
|
||||
unsigned int w; //width
|
||||
unsigned int k; //number of filters
|
||||
unsigned int r; //filter height
|
||||
unsigned int s; //filter width
|
||||
unsigned int u; //stride height
|
||||
unsigned int v; //stride width
|
||||
unsigned int p; //padding height
|
||||
unsigned int q; //padding width
|
||||
unsigned int d_h; //dilation height
|
||||
unsigned int d_w; //dilation width
|
||||
unsigned int Oh; //output height
|
||||
unsigned int Ow; //output width
|
||||
uint3 SC_fastdiv;
|
||||
uint3 OW_fastdiv;
|
||||
uint3 C_fastdiv;
|
||||
uint3 RS_fastdiv;
|
||||
uint3 S_fastdiv;
|
||||
uint3 OHOW_fastdiv;
|
||||
int64_t inc_next[3];
|
||||
unsigned int inChannelOffset;
|
||||
unsigned int weightKOffset;
|
||||
unsigned int PQ;
|
||||
unsigned int KPQ;
|
||||
unsigned int NKPQ;
|
||||
unsigned int CHW;
|
||||
} param_t;
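// For reference, the output spatial dims are assumed to follow the usual
// convolution relation (integer division):
//   Oh = (h + 2*p - d_h*(r - 1) - 1) / u + 1
//   Ow = (w + 2*q - d_w*(s - 1) - 1) / v + 1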
|
||||
|
||||
|
||||
/// Clears the predicate masks (no-op when `clear` is false)
|
||||
|
||||
template<const unsigned int K_STRID>
|
||||
__device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < K_STRID; ++s) {
|
||||
masks_[s][0] = clear ? 0 : masks_[s][0];
|
||||
masks_[s][1] = clear ? 0 : masks_[s][1];
|
||||
}
|
||||
}
|
||||
|
||||
template<const unsigned int K_STRID>
|
||||
__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset) {
|
||||
#pragma unroll
|
||||
for (int s = 0; s < K_STRID; ++s) {
|
||||
element_offset[s] += offset;
|
||||
}
|
||||
}
|
||||
|
||||
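// Computes per-thread base offsets into the input tensor for the A tile and
// builds the predicate bit-masks consumed by the tile loaders: masks[s][0]
// holds one validity bit per filter row r, masks[s][1] one bit per filter
// column s.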
template<const unsigned int TILE_ROWS,
|
||||
const unsigned int TILE_COLS,
|
||||
const unsigned int A_K_STRID,
|
||||
const unsigned int ROW_STEP>
|
||||
__device__ void prepareIteratorA(unsigned int thread_row,
|
||||
unsigned int masks[][2],
|
||||
int64_t element_offset[],
|
||||
const param_t param) {
|
||||
int offset_n[A_K_STRID];
|
||||
int offset_p[A_K_STRID];
|
||||
int offset_q[A_K_STRID];
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < A_K_STRID; ++s) {
|
||||
|
||||
const unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p;
|
||||
offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q;
|
||||
const int h = offset_p[s] * (int)param.u - (int) param.p;
|
||||
const int w = offset_q[s] * (int)param.v - (int) param.q;
|
||||
|
||||
element_offset[s] = offset_n[s] * (int64_t)param.CHW + h * (int64_t)(param.inChannelOffset) + w * (int64_t)param.c;
|
||||
|
||||
thread_row += ROW_STEP;
|
||||
}
|
||||
|
||||
clear_mask<A_K_STRID>(masks);
|
||||
|
||||
for (int r = 0; r < param.r; ++r) {
|
||||
#pragma unroll
|
||||
for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
|
||||
const int h = offset_p[s_idx] * param.u - param.p + r * param.d_h;
|
||||
|
||||
bool pred = (offset_n[s_idx] < param.n && h >= 0 && h < param.h);
|
||||
masks[s_idx][0] |= (pred << r);
|
||||
}
|
||||
}
|
||||
|
||||
for (int s = 0; s < param.s; ++s) {
|
||||
#pragma unroll
|
||||
for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
|
||||
const int w = offset_q[s_idx] * param.v - param.q + s * param.d_w;
|
||||
bool pred = (w >= 0 && w < param.w);
|
||||
masks[s_idx][1] |= (pred << s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
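// Copies `preload` bytes (16 by default) from global to shared memory with
// cp.async; when pred_guard is false the source size is passed as 0, so the
// hardware zero-fills the destination instead of reading out of bounds.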
template <int preload=16>
|
||||
__device__ void cp_async_zfill(void *ptr, void const *global_ptr, bool pred_guard = true) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
unsigned int smem_ptr;
|
||||
int src_in_bytes = pred_guard ? preload : 0;
|
||||
|
||||
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
|
||||
"%0, smem_ptr; }\n"
|
||||
: "=r"(smem_ptr)
|
||||
: "l"(ptr));
|
||||
|
||||
asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
|
||||
"l"(global_ptr),
|
||||
"n"(preload), "r"(src_in_bytes));
|
||||
#else
|
||||
GGML_UNUSED(ptr);
|
||||
GGML_UNUSED(global_ptr);
|
||||
GGML_UNUSED(pred_guard);
|
||||
#endif
|
||||
}
|
||||
|
||||
// copies a tile of the B (filter) matrix from global to shared memory; writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int NUM_THREADS>
|
||||
__device__ __forceinline__ void tileMemcpySwizzleB(
|
||||
const half* __restrict__ src,
|
||||
half* __restrict__ dst,
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
const unsigned int curC,
|
||||
const int64_t ki,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
param_t param
|
||||
) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||
|
||||
constexpr unsigned int TILE_COLS = 32;
|
||||
|
||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
// apply swizzle to the dst index
|
||||
const unsigned int src_index = thread_row * param.weightKOffset + ki;
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]),
|
||||
thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k);
|
||||
|
||||
#else
|
||||
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k) {
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
} else { // read 4 halves
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
#endif
|
||||
thread_row += ROW_STEP;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(ki);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// loads the A (input) tile into shared memory with swizzled writes; TILE_COLS is fixed at 32
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int NUM_THREADS>
|
||||
__device__ __forceinline__ unsigned int tileMemcpySwizzleA(
|
||||
const half* __restrict__ src,
|
||||
half* __restrict__ dst,
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
unsigned int masks[][2],
|
||||
const int64_t element_offset[],
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
param_t param
|
||||
) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||
|
||||
constexpr unsigned int TILE_COLS = 32;
|
||||
|
||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
|
||||
const unsigned int curC = start_k+thread_col*8;
|
||||
clear_mask<NUM_ITERS>(masks, curC >= end_k);
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
|
||||
// apply swizzle to the dst index
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid);
|
||||
#else
|
||||
if (valid) {
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
|
||||
} else {
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
#endif
|
||||
thread_row += ROW_STEP;
|
||||
}
|
||||
return curC;
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(masks);
|
||||
GGML_UNUSED(element_offset);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
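// Register-staging variant of the A-tile load: reads into per-thread float4
// registers instead of shared memory, and returns the current k column so the
// caller knows when the predicate masks were last refreshed.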
template<unsigned int TILE_ROWS,
|
||||
unsigned int TILE_COLS,
|
||||
unsigned int NUM_THREADS,
|
||||
unsigned int ELEMENTS_PER_THREAD>
|
||||
__device__ __forceinline__ unsigned int tileMemcpyLoadA(
|
||||
const half* __restrict__ src,
|
||||
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
unsigned int masks[][2],
|
||||
const int64_t element_offset[],
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
const unsigned int block_k,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
unsigned int oldC,
|
||||
param_t param
|
||||
) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
|
||||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int curC = start_k+block_k+thread_col*8;
|
||||
if (curC > oldC)
|
||||
clear_mask<NUM_ITERS>(masks, curC >= end_k);
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
|
||||
if (valid) {
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
|
||||
} else{
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
}
|
||||
return curC;
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst_reg);
|
||||
GGML_UNUSED(block_k);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(masks);
|
||||
GGML_UNUSED(element_offset);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(oldC);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int TILE_COLS,
|
||||
unsigned int NUM_THREADS,
|
||||
unsigned int ELEMENTS_PER_THREAD>
|
||||
__device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA(
|
||||
const half* __restrict__ src,
|
||||
half* __restrict__ dst,
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
unsigned int masks[][2],
|
||||
const int64_t element_offset[],
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
unsigned int iter_idx,
|
||||
const unsigned int block_k,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
unsigned int oldC,
|
||||
param_t param
|
||||
) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
|
||||
|
||||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int curC = start_k+block_k+thread_col*8;
|
||||
if (curC > oldC)
|
||||
clear_mask<NUM_ITERS>(masks, curC >= end_k);
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
|
||||
unsigned int dst_index = iter_idx;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
|
||||
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid);
|
||||
iter_idx += ITER_STEPS;
|
||||
}
|
||||
return curC;
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(block_k);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(masks);
|
||||
GGML_UNUSED(element_offset);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(iter_idx);
|
||||
GGML_UNUSED(oldC);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int TILE_COLS,
|
||||
unsigned int NUM_THREADS,
|
||||
unsigned int ELEMENTS_PER_THREAD>
|
||||
__device__ __forceinline__ void tileMemcpyLoadB(
|
||||
const half* __restrict__ src,
|
||||
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
const unsigned int curC,
|
||||
const int64_t ki,
|
||||
const unsigned int block_k,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
param_t param
|
||||
) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
|
||||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
unsigned int iter_idx = thread_row * param.weightKOffset + ki;
|
||||
unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS;
|
||||
const int ITER_STEPS = ROW_STEP * param.weightKOffset;
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
const unsigned int src_index = iter_idx;
|
||||
if (krow_idx < param.k && curC < end_k) {
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
} else { // read 4 halves
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
krow_idx += ROW_STEP;
|
||||
iter_idx += ITER_STEPS;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst_reg);
|
||||
GGML_UNUSED(block_k);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(ki);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int TILE_COLS,
|
||||
unsigned int NUM_THREADS,
|
||||
unsigned int ELEMENTS_PER_THREAD>
|
||||
__device__ __forceinline__ void tileMemcpyAsyncLoadB(
|
||||
const half *src,
|
||||
half *dst,
|
||||
const unsigned int curR,
|
||||
const unsigned int curS,
|
||||
const unsigned int curC,
|
||||
const int64_t ki,
|
||||
const unsigned int block_k,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col,
|
||||
unsigned int iter_src_idx,
|
||||
unsigned int iter_dst_idx,
|
||||
unsigned int krow_idx,
|
||||
const int ITER_SRC_STEPS,
|
||||
param_t param
|
||||
) {
|
||||
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
constexpr unsigned int ITER_DST_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
|
||||
|
||||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
iter_src_idx += ki;
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
const unsigned int src_index = iter_src_idx;
|
||||
unsigned int dst_index = iter_dst_idx;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
|
||||
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), krow_idx < param.k && curC < end_k);
|
||||
|
||||
iter_src_idx += ITER_SRC_STEPS;
|
||||
krow_idx += ROW_STEP;
|
||||
iter_dst_idx += ITER_DST_STEPS;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(src);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(block_k);
|
||||
GGML_UNUSED(curR);
|
||||
GGML_UNUSED(curS);
|
||||
GGML_UNUSED(ki);
|
||||
GGML_UNUSED(start_k);
|
||||
GGML_UNUSED(end_k);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
GGML_UNUSED(iter_src_idx);
|
||||
GGML_UNUSED(iter_dst_idx);
|
||||
GGML_UNUSED(krow_idx);
|
||||
GGML_UNUSED(ITER_SRC_STEPS);
|
||||
GGML_UNUSED(param);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
// stores per-thread register fragments (produced by the load helpers above)
// to shared memory, swizzling write indices to match the swizzled loads;
// TILE_COLS is fixed at 32 here
|
||||
template<unsigned int TILE_ROWS,
|
||||
unsigned int NUM_THREADS,
|
||||
unsigned int ELEMENTS_PER_THREAD>
|
||||
__device__ __forceinline__ void tileMemcpySwizzleStore(
|
||||
const float4 (&src_reg)[ELEMENTS_PER_THREAD],
|
||||
half* __restrict__ dst,
|
||||
unsigned int thread_row,
|
||||
const unsigned int thread_col
|
||||
) {
|
||||
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
|
||||
|
||||
constexpr unsigned int TILE_COLS = 32;
|
||||
|
||||
// reinterpret input/output as float4
|
||||
float4* dst_float4 = reinterpret_cast<float4*>(dst);
|
||||
|
||||
// # of threads is multiple of # of columns in the tile
|
||||
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
|
||||
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
|
||||
|
||||
// assign each thread a row/column in the tile, calculate how many iterations we need
|
||||
// to cover the whole tile
|
||||
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
|
||||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
|
||||
|
||||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++) {
|
||||
// apply swizzle to the dst index
|
||||
unsigned int dst_index = iter_idx;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
dst_float4[dst_index] = src_reg[i];
|
||||
iter_idx += ITER_STEPS;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(src_reg);
|
||||
GGML_UNUSED(dst);
|
||||
GGML_UNUSED(thread_row);
|
||||
GGML_UNUSED(thread_col);
|
||||
NO_DEVICE_CODE;
|
||||
#endif
|
||||
}
|
||||
|
||||
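// Loads a BN x 4 slice of the filter tensor into shared memory, transposed
// and padded by PAD columns to reduce bank conflicts; rows or k positions
// outside the tensor are zero-filled.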
template<typename T, const int BN, const int rowStrideA, const int layout,
|
||||
const bool vec_load, const int ksplit, const int PAD>
|
||||
__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel,
|
||||
T * __restrict__ smemweight,
|
||||
const unsigned int by,
|
||||
const unsigned int innerRowA,
|
||||
const unsigned int innerColA,
|
||||
const unsigned int weightKOffset,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
const param_t param){
|
||||
|
||||
const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
||||
const unsigned int kidx = start_k + innerColA * 4;
|
||||
#pragma unroll
|
||||
for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||
const unsigned int nidx = by * BN + innerRowA + offset;
|
||||
if (vec_load) {
|
||||
if (nidx < param.k && kidx < end_k) {
|
||||
if constexpr (std::is_same_v<T, float>){
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&kernel[nidx * weightKOffset + kidx])[0];
|
||||
smemweight[weight_sts_addr + offset + 0] = tmp.x;
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y;
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z;
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w;
|
||||
} else { // read 4 halves
|
||||
float2 tmp = reinterpret_cast<const float2 *>(&kernel[nidx * weightKOffset + kidx])[0];
|
||||
const half *val = reinterpret_cast<const half *>(&tmp);
|
||||
smemweight[weight_sts_addr + offset + 0] = val[0];
|
||||
smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1];
|
||||
smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2];
|
||||
smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3];
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
if (nidx < param.k && kidx + i < end_k) {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i];
|
||||
} else {
|
||||
smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
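// Gathers a BM x 4 slice of the implicit-GEMM input view into shared memory,
// decoding each flattened k index into (channel, r, s) with the fastdiv
// helpers and zero-filling out-of-bounds (padding) positions.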
template<const int BM, const int rowStrideA, const int layout,
|
||||
const bool vec_load, const int ksplit, const int PAD>
|
||||
__device__ __forceinline__ void loadInput(const float * __restrict__ input,
|
||||
float * __restrict__ smeminput,
|
||||
const unsigned int bx,
|
||||
const unsigned int innerRowA,
|
||||
const unsigned int innerColA,
|
||||
const unsigned int start_k,
|
||||
const unsigned int end_k,
|
||||
const unsigned int PQ,
|
||||
const unsigned int CHW,
|
||||
const unsigned int inChannelOffset,
|
||||
const param_t param) {
|
||||
const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4;
|
||||
const unsigned int kidx = start_k + innerColA * 4;
|
||||
#pragma unroll
|
||||
for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||
const unsigned int midx = bx * BM + innerRowA + offset;
|
||||
int n = (ksplit > 0) ? midx / PQ : blockIdx.z;
|
||||
const unsigned int npq_res = midx % PQ;
|
||||
const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p;
|
||||
const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.v - param.q;
|
||||
const unsigned int inOffset = n * CHW;
|
||||
if (vec_load) {
|
||||
const unsigned int cur0 = fastdiv(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const unsigned int cur1 = fastdiv(fastmodulo(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int cur2 = fastmodulo(fastmodulo(kidx,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel s offset
|
||||
const unsigned int curC = layout == 0 ? cur2 : cur0;
|
||||
const unsigned int curR = layout == 0 ? cur0 : cur1;
|
||||
const unsigned int curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) {
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
float4 tmp = reinterpret_cast<const float4 *>(&input[inOffset + inOffsetTmp])[0];
|
||||
smeminput[input_sts_addr + offset + 0] = tmp.x;
|
||||
smeminput[input_sts_addr + offset + BM+PAD] = tmp.y;
|
||||
smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z;
|
||||
smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w;
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i)
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
const unsigned int cur0 = fastdiv(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset
|
||||
const unsigned int cur1 = fastdiv(fastmodulo(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset
|
||||
const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i,
|
||||
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv),
|
||||
layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel s offset
|
||||
const unsigned int curC = layout == 0 ? cur2 : cur0;
|
||||
const unsigned int curR = layout == 0 ? cur0 : cur1;
|
||||
const unsigned int curS = layout == 0 ? cur1 : cur2;
|
||||
const int curH = posh_ori + curR * param.d_h; // input h
|
||||
const int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) {
|
||||
int inOffsetTmp = layout == 0 ?
|
||||
curH * inChannelOffset + curW * param.c + curC:
|
||||
curC * inChannelOffset + curH * param.w + curW;
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp];
|
||||
} else {
|
||||
smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256
|
||||
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
#include "common.cuh"
|
||||
|
||||
|
||||
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
|
||||
static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(const void * generic_ptr) {
|
||||
#ifdef CP_ASYNC_AVAILABLE
|
||||
return __cvta_generic_to_shared(generic_ptr);
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@
|
|||
#include "ggml-cuda/concat.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/conv2d.cuh"
|
||||
#include "ggml-cuda/conv2d-implicit.cuh"
|
||||
#include "ggml-cuda/conv2d-dw.cuh"
|
||||
#include "ggml-cuda/conv2d-transpose.cuh"
|
||||
#include "ggml-cuda/convert.cuh"
|
||||
|
|
@ -2741,7 +2742,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
ggml_cuda_op_im2col_3d(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D:
|
||||
ggml_cuda_op_conv2d(ctx, dst);
|
||||
ggml_cuda_op_conv2d_implicit(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_cuda_op_conv2d_dw(ctx, dst);
|
||||
|
|
|
|||
|
|
@ -1009,8 +1009,8 @@ public:
|
|||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
||||
struct stored_graph {
|
||||
ggml_context_ptr ctx_ptr;
|
||||
ggml_cgraph * graph;
|
||||
std::vector<uint8_t> buffer;
|
||||
ggml_cgraph * graph;
|
||||
};
|
||||
|
||||
private:
|
||||
|
|
@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
|||
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
|
||||
|
||||
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
|
||||
|
||||
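// reuse the per-device graph buffer across calls; it only grows, so repeated
// graph_compute requests avoid reallocating the metadata arena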
if (stored_graphs[device].buffer.size() < buf_size) {
|
||||
stored_graphs[device].buffer.resize(buf_size);
|
||||
}
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_size,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.mem_buffer =*/ stored_graphs[device].buffer.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||
|
|
@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
|
|||
}
|
||||
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
|
||||
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
|
||||
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
|
||||
stored_graphs[device].graph = graph;
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
|
|||
ExternalProject_Add(
|
||||
zendnn
|
||||
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
|
||||
GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08
|
||||
GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13
|
||||
PREFIX ${ZENDNN_PREFIX}
|
||||
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
|
||||
BINARY_DIR ${ZENDNN_BUILD_DIR}
|
||||
|
|
|
|||
|
|
@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat(
|
|||
}
|
||||
}
|
||||
|
||||
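// Maps one routed token back to its origin: i1 is the expert slot within
// ids->ne[0] (n_expert_used), i2 is the token row the activation came from.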
struct mmid_row_mapping {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
};
|
||||
|
||||
static void ggml_zendnn_compute_forward_mul_mat_id(
|
||||
ggml_backend_zendnn_context * ctx,
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0]; // expert weights
|
||||
const ggml_tensor * src1 = dst->src[1]; // inputs
|
||||
const ggml_tensor * ids = dst->src[2]; // expert ids
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// nothing to do if there are no tokens to process
|
||||
if (ne2 == 0 || ne11 == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_type const vec_dot_type = src0->type;
|
||||
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
GGML_ASSERT(ne3 == 1);
|
||||
|
||||
// row groups
|
||||
const int n_ids = ids->ne[0]; // n_expert_used
|
||||
const int n_as = ne02; // n_experts
|
||||
|
||||
std::vector<int64_t> matrix_row_counts(n_as, 0);
|
||||
std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as);
|
||||
|
||||
int64_t max_rows = 0;
|
||||
// group rows by expert (preprocessing step)
|
||||
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
|
||||
for (int id = 0; id < n_ids; ++id) {
|
||||
const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
GGML_ASSERT(i02 >= 0 && i02 < n_as);
|
||||
|
||||
matrix_rows[i02].push_back({id, (int32_t) iid1}); // cast avoids a narrowing error in list-init
|
||||
matrix_row_counts[i02]++;
|
||||
if (matrix_row_counts[i02] > max_rows) {
|
||||
max_rows = matrix_row_counts[i02];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (max_rows == 0) {
|
||||
return; // no rows to process
|
||||
}
|
||||
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
// size for converting src1 rows to vec_dot_type if needed
|
||||
const size_t nbw1 = row_size;
|
||||
const size_t nbw2 = nbw1 * ne11;
|
||||
const size_t nbw3 = nbw2 * ne12;
|
||||
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
|
||||
|
||||
// size for MoE gather/scatter buffers
|
||||
const size_t wdata_cur_size = max_rows * row_size;
|
||||
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);
|
||||
|
||||
// allocate single buffer for all needs
|
||||
const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size;
|
||||
if (ctx->work_size < total_size) {
|
||||
ctx->work_data.reset(new char[total_size]);
|
||||
ctx->work_size = total_size;
|
||||
}
|
||||
|
||||
// partition the buffer
|
||||
char * work_data = ctx->work_data.get();
|
||||
char * wdata_cur = work_data + src1_conv_size;
|
||||
char * dst_cur = wdata_cur + wdata_cur_size;
|
||||
|
||||
if (src1->type != vec_dot_type) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
for (int64_t i12 = 0; i12 < ne12; ++i12) {
|
||||
for (int64_t i11 = 0; i11 < ne11; ++i11) {
|
||||
const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
|
||||
void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
|
||||
from_float(src1_f32, src1_conv, ne10);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
|
||||
|
||||
// process each expert with gather -> gemm -> scatter pattern
|
||||
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||
|
||||
if (cne1 == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
|
||||
|
||||
// gather input rows for this expert
|
||||
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
|
||||
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
|
||||
const int64_t id = row_mapping.i1;
|
||||
const int64_t i11 = id % ne11;
|
||||
const int64_t i12 = row_mapping.i2;
|
||||
|
||||
std::memcpy(
|
||||
wdata_cur + ir1 * row_size,
|
||||
(const char *) wdata + (i11 + i12*ne11) * row_size,
|
||||
row_size
|
||||
);
|
||||
}
|
||||
|
||||
// batched gemm for all tokens in this expert
|
||||
if (!ggml_zendnn_sgemm(ctx,
|
||||
ne01, // m
|
||||
cne1, // n
|
||||
ne10, // k
|
||||
src0_cur,
|
||||
ne00, // lda
|
||||
wdata_cur,
|
||||
ne10, // ldb
|
||||
dst_cur,
|
||||
ne01, // ldc
|
||||
src0->type,
|
||||
vec_dot_type,
|
||||
dst->type)) {
|
||||
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
|
||||
}
|
||||
|
||||
// scatter output rows to destination
|
||||
#pragma omp parallel for num_threads(ctx->n_threads) schedule(static)
|
||||
for (int64_t ir1 = 0; ir1 < cne1; ++ir1) {
|
||||
const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1];
|
||||
const int64_t id = row_mapping.i1;
|
||||
const int64_t i1 = id;
|
||||
const int64_t i2 = row_mapping.i2;
|
||||
|
||||
std::memcpy(
|
||||
(char *) dst->data + i1*nb1 + i2*nb2,
|
||||
dst_cur + ir1 * ggml_row_size(dst->type, ne01),
|
||||
ggml_row_size(dst->type, ne01)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
// backend interface

static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {

@@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
        case GGML_OP_MUL_MAT:
            ggml_zendnn_compute_forward_mul_mat(ctx, node);
            break;
        case GGML_OP_MUL_MAT_ID:
            ggml_zendnn_compute_forward_mul_mat_id(ctx, node);
            break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:

@@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
            return true;

        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            {
                const ggml_tensor * weights = op->src[0];
                const ggml_tensor * inputs  = op->src[1];

@@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
                    ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
                    return false;
                }
                // MUL_MAT_ID performs best with a moderate number of experts due to its
                // gather + batched matmul + scatter approach. Future versions will leverage
                // ZenDNN's grouped_gemm for better scalability with larger expert counts:
                // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md
                if (op->op == GGML_OP_MUL_MAT_ID) {
                    const int64_t n_experts   = weights->ne[2];
                    const int64_t max_experts = 32;
                    if (n_experts > max_experts) {
                        return false;
                    }
                }
                switch (weights->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_BF16:
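As a side note, the gather -> gemm -> scatter pattern in the MUL_MAT_ID path above can be illustrated in isolation. The following is a minimal, self-contained sketch, not the ZenDNN implementation: the row mapping, buffer sizes, and the identity "gemm" are simplified stand-ins.

```cpp
// Illustrative sketch of gather -> gemm -> scatter over expert row mappings.
#include <cstring>
#include <vector>

struct row_mapping { int i1; int i2; }; // source row index / batch index (assumed layout)

int main() {
    const int n_rows = 3, row_size = 4;
    std::vector<float> src(8 * row_size, 1.0f);     // all input rows, contiguous
    std::vector<float> gathered(n_rows * row_size); // dense buffer for one expert
    std::vector<float> dst(8 * row_size, 0.0f);

    const row_mapping rows[n_rows] = { {0, 0}, {2, 0}, {5, 0} };

    // gather: copy only the rows routed to this expert into a dense buffer
    for (int r = 0; r < n_rows; ++r) {
        std::memcpy(gathered.data() + r * row_size,
                    src.data() + rows[r].i1 * row_size,
                    row_size * sizeof(float));
    }

    // "gemm": a real implementation runs one batched matmul over `gathered`
    // (identity here for brevity)

    // scatter: write each result row back to its original position
    for (int r = 0; r < n_rows; ++r) {
        std::memcpy(dst.data() + rows[r].i1 * row_size,
                    gathered.data() + r * row_size,
                    row_size * sizeof(float));
    }
    return 0;
}
```

The dense intermediate buffer is what makes a single batched GEMM per expert possible; this is also why the `supports_op` check above caps the expert count until grouped_gemm is available.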
@@ -0,0 +1,266 @@
{%- macro format_parameters(properties, required) -%}
{%- set standard_keys = ['description', 'type', 'properties', 'required', 'nullable'] -%}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in properties | dictsort -%}
{%- set add_comma = false -%}
{%- if key not in standard_keys -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{{ key }}:{
{%- if value['description'] -%}
description:<|"|>{{ value['description'] }}<|"|>
{%- set add_comma = true -%}
{%- endif -%}
{%- if value['nullable'] %}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
nullable:true
{%- endif -%}
{%- if value['type'] | upper == 'STRING' -%}
{%- if value['enum'] -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
enum:{{ format_argument(value['enum']) }}
{%- endif -%}
{%- elif value['type'] | upper == 'OBJECT' -%}
,properties:{
{%- if value['properties'] is defined and value['properties'] is mapping -%}
{{- format_parameters(value['properties'], value['required'] | default([])) -}}
{%- elif value is mapping -%}
{{- format_parameters(value, value['required'] | default([])) -}}
{%- endif -%}
}
{%- if value['required'] -%}
,required:[
{%- for item in value['required'] | default([]) -%}
<|"|>{{- item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- endif -%}
{%- elif value['type'] | upper == 'ARRAY' -%}
{%- if value['items'] is mapping and value['items'] -%}
,items:{
{%- set ns_items = namespace(found_first=false) -%}
{%- for item_key, item_value in value['items'] | dictsort -%}
{%- if item_value is not none -%}
{%- if ns_items.found_first %},{% endif -%}
{%- set ns_items.found_first = true -%}
{%- if item_key == 'properties' -%}
properties:{
{%- if item_value is mapping -%}
{{- format_parameters(item_value, value['items']['required'] | default([])) -}}
{%- endif -%}
}
{%- elif item_key == 'required' -%}
required:[
{%- for req_item in item_value -%}
<|"|>{{- req_item -}}<|"|>
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
]
{%- elif item_key == 'type' -%}
{%- if item_value is string -%}
type:{{ format_argument(item_value | upper) }}
{%- else -%}
type:{{ format_argument(item_value | map('upper') | list) }}
{%- endif -%}
{%- else -%}
{{ item_key }}:{{ format_argument(item_value) }}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
}
{%- endif -%}
{%- endif -%}
{%- if add_comma %},{%- else -%} {%- set add_comma = true -%} {% endif -%}
type:<|"|>{{ value['type'] | upper }}<|"|>}
{%- endif -%}
{%- endfor -%}
{%- endmacro -%}
{%- macro format_function_declaration(tool_data) -%}
declaration:{{- tool_data['function']['name'] -}}{description:<|"|>{{- tool_data['function']['description'] -}}<|"|>
{%- set params = tool_data['function']['parameters'] -%}
{%- if params -%}
,parameters:{
{%- if params['properties'] -%}
properties:{ {{- format_parameters(params['properties'], params['required']) -}} },
{%- endif -%}
{%- if params['required'] -%}
required:[
{%- for item in params['required'] -%}
<|"|>{{- item -}}<|"|>
{{- ',' if not loop.last -}}
{%- endfor -%}
],
{%- endif -%}
{%- if params['type'] -%}
type:<|"|>{{- params['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
{%- if 'response' in tool_data['function'] -%}
{%- set response_declaration = tool_data['function']['response'] -%}
,response:{
{%- if response_declaration['description'] -%}
description:<|"|>{{- response_declaration['description'] -}}<|"|>,
{%- endif -%}
{%- if response_declaration['type'] | upper == 'OBJECT' -%}
type:<|"|>{{- response_declaration['type'] | upper -}}<|"|>}
{%- endif -%}
{%- endif -%}
}
{%- endmacro -%}
{%- macro format_argument(argument, escape_keys=True) -%}
{%- if argument is string -%}
{{- '<|"|>' + argument + '<|"|>' -}}
{%- elif argument is boolean -%}
{{- 'true' if argument else 'false' -}}
{%- elif argument is mapping -%}
{{- '{' -}}
{%- set ns = namespace(found_first=false) -%}
{%- for key, value in argument | dictsort -%}
{%- if ns.found_first %},{% endif -%}
{%- set ns.found_first = true -%}
{%- if escape_keys -%}
{{- '<|"|>' + key + '<|"|>' -}}
{%- else -%}
{{- key -}}
{%- endif -%}
:{{- format_argument(value, escape_keys=escape_keys) -}}
{%- endfor -%}
{{- '}' -}}
{%- elif argument is sequence -%}
{{- '[' -}}
{%- for item in argument -%}
{{- format_argument(item, escape_keys=escape_keys) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- ']' -}}
{%- else -%}
{{- argument -}}
{%- endif -%}
{%- endmacro -%}
{%- macro strip_thinking(text) -%}
{%- set ns = namespace(result='') -%}
{%- for part in text.split('<channel|>') -%}
{%- if '<|channel>' in part -%}
{%- set ns.result = ns.result + part.split('<|channel>')[0] -%}
{%- else -%}
{%- set ns.result = ns.result + part -%}
{%- endif -%}
{%- endfor -%}
{{- ns.result | trim -}}
{%- endmacro -%}

{%- set ns = namespace(prev_message_type=None) -%}
{%- set loop_messages = messages -%}
{{ bos_token }}
{#- Handle System/Tool Definitions Block -#}
{%- if (enable_thinking is defined and enable_thinking) or tools or messages[0]['role'] in ['system', 'developer'] -%}
{{- '<|turn>system\n' -}}

{#- Inject Thinking token at the very top of the FIRST system turn -#}
{%- if enable_thinking is defined and enable_thinking -%}
{{- '<|think|>' -}}
{%- set ns.prev_message_type = 'think' -%}
{%- endif -%}

{%- if messages[0]['role'] in ['system', 'developer'] -%}
{{- messages[0]['content'] | trim -}}
{%- set loop_messages = messages[1:] -%}
{%- endif -%}

{%- if tools -%}
{%- for tool in tools %}
{{- '<|tool>' -}}
{{- format_function_declaration(tool) | trim -}}
{{- '<tool|>' -}}
{%- endfor %}
{%- set ns.prev_message_type = 'tool' -%}
{%- endif -%}

{{- '<turn|>\n' -}}
{%- endif %}

{#- Loop through messages -#}
{%- for message in loop_messages -%}
{%- set ns.prev_message_type = None -%}
{%- set role = 'model' if message['role'] == 'assistant' else message['role'] -%}
{{- '<|turn>' + role + '\n' }}

{%- if message['tool_calls'] -%}
{%- for tool_call in message['tool_calls'] -%}
{%- set function = tool_call['function'] -%}
{{- '<|tool_call>call:' + function['name'] + '{' -}}
{%- if function['arguments'] is mapping -%}
{%- set ns_args = namespace(found_first=false) -%}
{%- for key, value in function['arguments'] | dictsort -%}
{%- if ns_args.found_first %},{% endif -%}
{%- set ns_args.found_first = true -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- endfor -%}
{%- elif function['arguments'] is string -%}
{{- function['arguments'] -}}
{%- endif -%}
{{- '}<tool_call|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_call' -%}
{%- endif -%}

{%- if message['tool_responses'] -%}
{#- Tool Response handling -#}
{%- for tool_response in message['tool_responses'] -%}
{{- '<|tool_response>' -}}
{%- if tool_response['response'] is mapping -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{' -}}
{%- for key, value in tool_response['response'] | dictsort -%}
{{- key -}}:{{- format_argument(value, escape_keys=False) -}}
{%- if not loop.last %},{% endif -%}
{%- endfor -%}
{{- '}' -}}
{%- else -%}
{{- 'response:' + tool_response['name'] | default('unknown') + '{value:' + format_argument(tool_response['response'], escape_keys=False) + '}' -}}
{%- endif -%}
{{- '<tool_response|>' -}}
{%- endfor -%}
{%- set ns.prev_message_type = 'tool_response' -%}
{%- endif -%}

{%- if message['content'] is string -%}
{%- if role == 'model' -%}
{{- strip_thinking(message['content']) -}}
{%- else -%}
{{- message['content'] | trim -}}
{%- endif -%}
{%- elif message['content'] is sequence -%}
{%- for item in message['content'] -%}
{%- if item['type'] == 'text' -%}
{%- if role == 'model' -%}
{{- strip_thinking(item['text']) -}}
{%- else -%}
{{- item['text'] | trim -}}
{%- endif -%}
{%- elif item['type'] == 'image' -%}
{{- '\n\n<|image|>\n\n' -}}
{%- set ns.prev_message_type = 'image' -%}
{%- elif item['type'] == 'audio' -%}
{{- '<|audio|>' -}}
{%- set ns.prev_message_type = 'audio' -%}
{%- elif item['type'] == 'video' -%}
{{- '\n\n<|video|>\n\n' -}}
{%- set ns.prev_message_type = 'video' -%}
{%- endif -%}
{%- endfor -%}
{%- endif -%}

{%- if not (message['tool_responses'] and not message['content']) -%}
{{- '<turn|>\n' -}}
{%- endif -%}
{%- endfor -%}

{%- if add_generation_prompt -%}
{%- if ns.prev_message_type != 'tool_response' -%}
{{- '<|turn>model\n' -}}
{%- endif -%}
{%- if not enable_thinking | default(false) -%}
{{- '<|channel>thought\n<channel|>' -}}
{%- endif -%}
{%- endif -%}
@@ -1,8 +1,8 @@
#pragma once

#include "llama-context.h"
#include "ggml.h"
#include "stdint.h"
#include "llama.h"

#include <cstdint>

// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve.
LLAMA_API struct ggml_cgraph * llama_graph_reserve(

@@ -10,3 +10,47 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve(
        uint32_t n_tokens,
        uint32_t n_seqs,
        uint32_t n_outputs);

// Get the default ggml_type for a given ftype.
LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype);

// Quantization state.
struct quantize_state_impl;

LLAMA_API quantize_state_impl * llama_quant_init(
        const llama_model * model,
        const llama_model_quantize_params * params);

LLAMA_API void llama_quant_free(quantize_state_impl * qs);

// Descriptor for constructing a mock model for quantization testing.
struct llama_quant_model_desc {
    const char * architecture;
    uint32_t n_embd;
    uint32_t n_ff;
    uint32_t n_layer;
    uint32_t n_head;
    uint32_t n_head_kv;
    uint32_t n_expert;
    uint32_t n_embd_head_k;
    uint32_t n_embd_head_v;
};

// Create a mock model from a metadata descriptor (for testing).
// The returned model must be freed with llama_model_free().
LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc);

// Returns true if this tensor should be quantized (based on name, dims, params).
LLAMA_API bool llama_quant_tensor_allows_quantization(
        const quantize_state_impl * qs,
        const ggml_tensor * tensor);

// Compute quantization type assignments for a list of tensors.
// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter).
// result_types: caller-allocated array of n_tensors elements, filled with assigned types.
LLAMA_API void llama_quant_compute_types(
        quantize_state_impl * qs,
        llama_ftype ftype,
        ggml_tensor ** tensors,
        ggml_type * result_types,
        size_t n_tensors);
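A minimal sketch of how an external tool might drive these declarations, under the assumption that `model` and a raw `tensors` array were obtained elsewhere; error handling and tensor acquisition are elided:

```cpp
#include "llama.h"
#include "llama-ext.h"
#include <vector>

// Filter to quantizable tensors, then ask the library for per-tensor types.
void assign_types(const llama_model * model, ggml_tensor ** tensors, size_t n_tensors) {
    llama_model_quantize_params params = llama_model_quantize_default_params();
    quantize_state_impl * qs = llama_quant_init(model, &params);

    // keep only tensors that are actually quantizable
    std::vector<ggml_tensor *> quantizable;
    for (size_t i = 0; i < n_tensors; i++) {
        if (llama_quant_tensor_allows_quantization(qs, tensors[i])) {
            quantizable.push_back(tensors[i]);
        }
    }

    // caller-allocated result array, one assigned type per tensor
    std::vector<ggml_type> types(quantizable.size());
    llama_quant_compute_types(qs, LLAMA_FTYPE_MOSTLY_Q4_K_M,
                              quantizable.data(), types.data(), quantizable.size());

    llama_quant_free(qs);
}
```

The ftype shown is just one valid choice; the same state object can be reused across calls because `llama_quant_compute_types` resets its per-computation counters (see the implementation further below).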
@@ -66,9 +66,8 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(

    LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

    // note: the SWA cache is never quantized because it is relatively small
    kv_swa = std::make_unique<llama_kv_cache>(
            model, GGML_TYPE_F16, GGML_TYPE_F16,
            model, type_k, type_v,
            v_trans, offload, unified, size_swa, n_seq_max, n_pad,
            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
@@ -1,11 +1,11 @@
#include "llama.h"
#include "llama-impl.h"
#include "llama-model.h"
#include "llama-model-loader.h"
#include "llama-ext.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <string>
#include <cinttypes>
#include <fstream>
#include <mutex>

@@ -197,6 +197,7 @@ struct quantize_state_impl {

// per-tensor metadata, computed in the preliminary loop and used in the main loop
struct tensor_metadata {
    std::string name;
    ggml_type target_type;
    tensor_category category;
    std::string remapped_imatrix_name;

@@ -788,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds
// given a file type, get the default tensor type
//

static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
    switch (ftype) {
        case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0;
        case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1;

@@ -827,16 +828,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) {
        case LLAMA_FTYPE_MOSTLY_IQ3_S:
        case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S;

        default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
        default: return GGML_TYPE_COUNT;
    }
}

static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata) {
    for (auto & tm : metadata) {
        tensor_category cat = tensor_get_category(tm.name);
        tm.category = cat;

        if (category_is_attn_v(cat)) {
            ++qs.n_attention_wv;
        }

        if (cat == tensor_category::OUTPUT) {
            qs.has_tied_embeddings = false;
        }
    }
    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
}

//
// main quantization driver
//

static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
    ggml_type default_type;
    llama_ftype ftype = params->ftype;

    int nthread = params->nthread;

@@ -845,7 +862,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
        nthread = std::thread::hardware_concurrency();
    }

    default_type = llama_ftype_get_default_type(ftype);
    ggml_type default_type = llama_ftype_get_default_type(ftype);
    if (default_type == GGML_TYPE_COUNT) {
        throw std::runtime_error(format("invalid output file type %d\n", ftype));
    }

    // mmap consistently increases speed on Linux, and also increases speed on Windows with
    // hot cache. It may cause a slowdown on macOS, possibly related to free memory.

@@ -964,6 +984,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
        });
    }

    // compute tensor metadata once and cache it
    std::vector<tensor_metadata> metadata(tensors.size());
    for (size_t i = 0; i < tensors.size(); ++i) {
        metadata[i].name = ggml_get_name(tensors[i]->tensor);
    }

    // initialize quantization state counters and metadata categories
    init_quantize_state_counters(qs, metadata);

    int idx = 0;
    uint16_t n_split = 1;

@@ -976,25 +1005,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
    std::vector<gguf_context_ptr> ctx_outs(n_split);
    ctx_outs[0] = std::move(ctx_out);

    // compute tensor metadata once and cache it
    std::vector<tensor_metadata> metadata(tensors.size());

    // initialize quantization state before preliminary loop (counters for use_more_bits)
    {
        for (size_t i = 0; i < tensors.size(); ++i) {
            const auto cat = tensor_get_category(tensors[i]->tensor->name);
            if (category_is_attn_v(cat)) {
                ++qs.n_attention_wv;
            }
            if (cat == tensor_category::OUTPUT) {
                qs.has_tied_embeddings = false;
            }
            metadata[i].category = cat; // save and re-use the category while we're at it
        }
        // these also need to be set to n_layer by default
        qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer;
    }

    // flag for --dry-run
    bool will_require_imatrix = false;

@@ -1005,7 +1015,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
    for (size_t i = 0; i < tensors.size(); ++i) {
        const auto * it = tensors[i];
        const struct ggml_tensor * tensor = it->tensor;
        const std::string name = ggml_get_name(tensor);

        uint16_t i_split = params->keep_split ? it->idx : 0;
        if (!ctx_outs[i_split]) {

@@ -1034,7 +1043,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
                " - offending tensor: %s\n"
                " - target type:      %s\n"
                "============================================================================\n\n",
                name.c_str(), ggml_type_name(metadata[i].target_type));
                metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type));
            throw std::runtime_error("this quantization requires an imatrix!");
        }
    }

@@ -1107,7 +1116,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
            new_ofstream(weight.idx);
        }

        const std::string name = ggml_get_name(tensor);
        const size_t tensor_size = ggml_nbytes(tensor);

        if (!params->dry_run) {

@@ -1238,9 +1246,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
        total_size_new += new_size;

        // update the gguf meta data as we go
        gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
        gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type);
        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size);
        gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data);

        // write tensor data + padding
        fout.write((const char *) new_data, new_size);

@@ -1305,3 +1313,89 @@ uint32_t llama_model_quantize(

    return 0;
}

//
// Helper functions for external tools exposed in llama-ext.h
//

quantize_state_impl * llama_quant_init(
        const llama_model * model,
        const llama_model_quantize_params * params) {
    return new quantize_state_impl(*model, params);
}

void llama_quant_free(quantize_state_impl * qs) {
    delete qs;
}

llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) {
    struct llama_model_params mparams = llama_model_default_params();
    auto * model = new llama_model(mparams);

    model->arch = llm_arch_from_string(desc->architecture);

    // infer llm_type: only LLM_TYPE_70B matters for quantization logic
    if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) {
        model->type = LLM_TYPE_70B;
    }

    model->hparams.n_embd             = desc->n_embd;
    model->hparams.n_embd_head_k_full = desc->n_embd_head_k;
    model->hparams.n_embd_head_v_full = desc->n_embd_head_v;
    model->hparams.n_layer            = desc->n_layer;
    model->hparams.n_expert           = desc->n_expert;

    for (uint32_t i = 0; i < desc->n_layer; i++) {
        model->hparams.n_head_arr[i]    = desc->n_head;
        model->hparams.n_head_kv_arr[i] = desc->n_head_kv;
        model->hparams.n_ff_arr[i]      = desc->n_ff;
    }

    return model;
}

bool llama_quant_tensor_allows_quantization(
        const quantize_state_impl * qs,
        const ggml_tensor * tensor) {
    return tensor_allows_quantization(qs->params, qs->model.arch, tensor);
}

void llama_quant_compute_types(
        quantize_state_impl * qs,
        llama_ftype ftype,
        ggml_tensor ** tensors,
        ggml_type * result_types,
        size_t n_tensors) {
    // reset per-computation state
    qs->n_attention_wv = 0;
    qs->n_ffn_down = 0;
    qs->n_ffn_gate = 0;
    qs->n_ffn_up = 0;
    qs->i_attention_wv = 0;
    qs->i_ffn_down = 0;
    qs->i_ffn_gate = 0;
    qs->i_ffn_up = 0;
    qs->n_fallback = 0;
    qs->has_imatrix = false;
    qs->has_tied_embeddings = true;

    // build metadata from tensor names
    std::vector<tensor_metadata> metadata(n_tensors);
    for (size_t i = 0; i < n_tensors; i++) {
        metadata[i].name = ggml_get_name(tensors[i]);
    }

    // initialize counters and categories
    init_quantize_state_counters(*qs, metadata);

    // use a local copy of params with the requested ftype
    llama_model_quantize_params local_params = *qs->params;
    local_params.ftype = ftype;

    ggml_type default_type = llama_ftype_get_default_type(ftype);

    // compute types
    for (size_t i = 0; i < n_tensors; i++) {
        result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]);
    }
}
@@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
            };
            break;
        case LLAMA_VOCAB_PRE_TYPE_GEMMA4:
            // Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the
            // normalizer, then BPE merges run on the whole text without
            // word-level pre-splitting. We only need to split on newlines
            // since BPE merge lookup asserts no newlines in tokens.
            regex_exprs = {
                "[^\\n]+|[\\n]+",
            };
            byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
            break;
        default:
            // default regex for BPE tokenization pre-processing
            regex_exprs = {

@@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
    }

    std::vector<std::string> regex_exprs;
    bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8)
};

struct llm_tokenizer_bpe_session {

@@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session {

    void tokenize(const std::string & text, std::vector<llama_token> & output) {
        int final_prev_index = -1;
        const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
        const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode);

        symbols_final.clear();
        auto tok_pre = vocab.get_pre_type();

        for (const auto & word : word_collection) {
            work_queue = llm_bigram_bpe::queue();

@@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session {
            if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                offset = word.size();
            } else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) {
                // fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343
                auto tok = vocab.text_to_token(word);
                if (tok != LLAMA_TOKEN_NULL) {
                    symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                    offset = word.size();
                }
            }

            while (offset < word.size()) {

@@ -1864,7 +1883,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
        special_pad_id  = 3; // <|plamo:pad|>
        special_mask_id = LLAMA_TOKEN_NULL;
    } else if (tokenizer_model == "gemma4") {
        type = LLAMA_VOCAB_TYPE_SPM;
        type = LLAMA_VOCAB_TYPE_BPE;

        // read bpe merges and populate bpe ranks
        const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
        if (merges_keyidx == -1) {
            throw std::runtime_error("cannot find tokenizer merges in model file\n");
        }
        {
            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
            for (int i = 0; i < n_merges; i++) {
                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);

                std::string first;
                std::string second;

                const size_t pos = word.find(' ', 1);

                if (pos != std::string::npos) {
                    first  = word.substr(0, pos);
                    second = word.substr(pos + 1);
                }

                bpe_ranks.emplace(std::make_pair(first, second), i);
            }
        }

        // default special tokens (to be read from GGUF)
        special_bos_id = LLAMA_TOKEN_NULL;

@@ -1874,7 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
        special_pad_id  = LLAMA_TOKEN_NULL;
        special_mask_id = LLAMA_TOKEN_NULL;

        tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
        tokenizer_pre = "gemma4";
    } else {
        throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
    }

@@ -1882,6 +1925,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
    // for now, only BPE models have pre-tokenizers
    if (type == LLAMA_VOCAB_TYPE_BPE) {
        add_space_prefix = false;
        escape_whitespaces = false;
        clean_spaces = true;
        if (tokenizer_pre.empty()) {
            LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);

@@ -1948,6 +1992,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
        } else if (
                tokenizer_pre == "jais-2") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
        } else if (
                tokenizer_pre == "gemma4") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
            escape_whitespaces = true;
        } else if (
                tokenizer_pre == "jina-v1-en" ||
                tokenizer_pre == "jina-v2-code" ||

@@ -3045,6 +3093,10 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
        if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
            std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);

            if (escape_whitespaces) {
                llama_escape_whitespace(text);
            }

#ifdef PRETOKENIZERDEBUG
            LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
#endif

@@ -3224,6 +3276,12 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
        return _try_copy(token_text.data(), token_text.size());
    }
    if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
        if (escape_whitespaces) {
            // SPM-style BPE: tokens contain ▁ for spaces
            std::string result = token_text;
            llama_unescape_whitespace(result);
            return _try_copy(result.data(), result.size());
        }
        std::string result = llama_decode_text(token_text);
        return _try_copy(result.data(), result.size());
    }

@@ -3654,9 +3712,7 @@ int llama_vocab::max_token_len() const {

int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
    GGML_ASSERT(token_left.find(' ') == std::string::npos);
    GGML_ASSERT(token_left.find('\n') == std::string::npos);
    GGML_ASSERT(token_right.find(' ') == std::string::npos);
    GGML_ASSERT(token_right.find('\n') == std::string::npos);

    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
    if (it == pimpl->bpe_ranks.end()) {

@@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
    LLAMA_VOCAB_PRE_TYPE_TINY_AYA  = 47,
    LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
    LLAMA_VOCAB_PRE_TYPE_JAIS2     = 49,
    LLAMA_VOCAB_PRE_TYPE_GEMMA4    = 50,
};

struct LLM_KV;
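To make the GEMMA4 pre-tokenization concrete: whitespace is first escaped to ▁ (U+2581), and the only split points are newline runs, mirroring the `"[^\n]+|[\n]+"` expression above. A minimal standalone sketch of that split (the real code goes through unicode_regex_split; token lookup is omitted):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Split into maximal runs of newlines vs. non-newlines, like "[^\n]+|[\n]+".
static std::vector<std::string> split_gemma4(const std::string & s) {
    std::vector<std::string> out;
    size_t i = 0;
    while (i < s.size()) {
        const bool nl = s[i] == '\n';
        size_t j = i;
        while (j < s.size() && (s[j] == '\n') == nl) {
            j++;
        }
        out.push_back(s.substr(i, j - i));
        i = j;
    }
    return out;
}

int main() {
    // input as it looks after llama_escape_whitespace: "hello\u2581world\n\nbye"
    // yields three fragments: the escaped text, the newline run, the tail;
    // newline-only fragments are looked up directly in the vocab
    for (const auto & w : split_gemma4("hello\u2581world\n\nbye")) {
        std::cout << "[" << w.size() << " bytes]\n";
    }
    return 0;
}
```

This is why `find_bpe_rank` can drop its no-newline asserts above: newline runs never reach the merge loop as parts of larger words.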
@@ -912,7 +912,7 @@ bool unicode_cpt_is_han(uint32_t cpt) {
    return false;
}

std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) {
    // unicode categories
    static const std::map<std::string, int> k_ucat_enum = {
        { "\\p{N}", unicode_cpt_flags::NUMBER },

@@ -1099,5 +1099,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
        start += offset;
    }

    return unicode_byte_encoding_process(bpe_words);
    if (byte_encode) {
        return unicode_byte_encoding_process(bpe_words);
    }

    return bpe_words;
}

@@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt);

bool unicode_cpt_is_han(uint32_t cpt);

std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true);
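Because the new parameter defaults to `true`, every existing call site keeps the GPT-2 byte-level behavior; only SPM-style BPE opts out. A tiny sketch of the opt-out call:

```cpp
#include "unicode.h"
#include <string>
#include <vector>

// SPM-style BPE wants the split fragments in raw UTF-8, not byte-encoded.
std::vector<std::string> split_raw(const std::string & text,
                                   const std::vector<std::string> & exprs) {
    return unicode_regex_split(text, exprs, /*byte_encode=*/false);
}
```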
@@ -1,5 +1,6 @@
*
!*.*
!snapshots/
*.o
ggml-common.h
**/*.swp

@@ -274,6 +274,12 @@ if (TARGET cpp-httplib)
    add_executable(test-gguf-model-data test-gguf-model-data.cpp)
    target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common)
    llama_test(test-gguf-model-data LABEL "model")

    # test-quant-type-selection requires gguf-model-data for remote model metadata
    llama_build_and_test(test-quant-type-selection.cpp LABEL "model")
    target_link_libraries(test-quant-type-selection PRIVATE gguf-model-data)
    target_compile_definitions(test-quant-type-selection PRIVATE
        SNAPSHOT_DIR="${CMAKE_CURRENT_SOURCE_DIR}/snapshots")
    endif()
endif()
@@ -125,6 +125,35 @@ static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) {
}

static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) {
    // Handle array-valued fields (e.g. per-layer head counts in hybrid models)
    // by reading the first element as a representative value.
    if (vtype == GGUF_TYPE_ARRAY) {
        int32_t elem_type;
        uint64_t count;
        if (!r.read_val(elem_type)) {
            return false;
        }
        if (!r.read_val(count)) {
            return false;
        }
        if (count == 0) {
            return false;
        }
        // Read first element, skip the rest
        if (!gguf_read_uint32_val(r, elem_type, out)) {
            return false;
        }
        for (uint64_t i = 1; i < count; i++) {
            size_t sz = gguf_val_type_size(elem_type);
            if (sz == 0) {
                return false;
            }
            if (!r.skip(sz)) {
                return false;
            }
        }
        return true;
    }
    if (vtype == GGUF_TYPE_UINT8) {
        uint8_t v;
        if (!r.read_val(v)) {

@@ -487,7 +516,8 @@ static std::string detect_gguf_filename(const std::string & repo, const std::str
static std::optional<gguf_remote_model> fetch_and_parse(
        const std::string & repo,
        const std::string & filename,
        const std::string & cache_path) {
        const std::string & cache_path,
        bool verbose) {
    std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename;

    // Progressive download inspired by RangeView.fetchChunk()

@@ -496,7 +526,9 @@ static std::optional<gguf_remote_model> fetch_and_parse(
    const size_t max_chunk = 64 * 1024 * 1024;

    while (chunk_size <= max_chunk) {
        fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
        if (verbose) {
            fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str());
        }

        char range_buf[64];
        snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1);
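The progressive-download idea above can be summarized as: request a small byte range first and double it until the metadata parses or a cap is reached. A minimal sketch under stated assumptions — `fetch_range` and `parses` are hypothetical callbacks standing in for the real HTTP and GGUF-parsing layers, and the 1 MiB starting size is assumed:

```cpp
#include <functional>
#include <optional>
#include <string>

// Fetch a growing prefix of a remote file until it parses, doubling each try.
std::optional<std::string> fetch_progressive(
        const std::function<std::string(size_t, size_t)> & fetch_range, // [begin, end]
        const std::function<bool(const std::string &)> & parses) {
    size_t chunk_size = 1 * 1024 * 1024;        // starting size (assumed)
    const size_t max_chunk = 64 * 1024 * 1024;  // cap from the code above
    while (chunk_size <= max_chunk) {
        std::string buf = fetch_range(0, chunk_size - 1); // i.e. "bytes=0-(N-1)"
        if (parses(buf)) {
            return buf;      // header + metadata fit in this prefix
        }
        chunk_size *= 2;     // not enough bytes yet; double and retry
    }
    return std::nullopt;
}
```

The payoff is that metadata for multi-gigabyte GGUF files is usually recovered from the first few megabytes, without ever downloading the tensor data.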
@@ -542,7 +574,8 @@ static std::optional<gguf_remote_model> fetch_or_cached(
        const std::string & repo,
        const std::string & filename,
        const std::string & cdir,
        const std::string & repo_part) {
        const std::string & repo_part,
        bool verbose) {
    std::string cache_path = get_cache_file_path(cdir, repo_part, filename);

    {

@@ -550,20 +583,23 @@ static std::optional<gguf_remote_model> fetch_or_cached(
        if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) {
            auto result = gguf_parse_meta(cached);
            if (result.has_value()) {
                fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
                if (verbose) {
                    fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str());
                }
                return result;
            }
        }
    }

    fs_create_directory_with_parents(cdir);
    return fetch_and_parse(repo, filename, cache_path);
    return fetch_and_parse(repo, filename, cache_path, verbose);
}

std::optional<gguf_remote_model> gguf_fetch_model_meta(
        const std::string & repo,
        const std::string & quant,
        const std::string & cache_dir) {
        const std::string & cache_dir,
        bool verbose) {
    std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
    std::string repo_part = sanitize_for_path(repo);

@@ -573,7 +609,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
        return std::nullopt;
    }

    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
    if (!model_opt.has_value()) {
        fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
        return std::nullopt;

@@ -588,8 +624,10 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
        return std::nullopt;
    }

    fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
            model.n_split, model.n_split - 1);
    if (verbose) {
        fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
                model.n_split, model.n_split - 1);
    }

    for (int i = 2; i <= model.n_split; i++) {
        char num_buf[6], total_buf[6];

@@ -597,7 +635,7 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
        snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
        std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";

        auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
        auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
        if (!shard.has_value()) {
            fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
            return std::nullopt;

@@ -620,7 +658,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
gguf_context_ptr gguf_fetch_gguf_ctx(
        const std::string & repo,
        const std::string & quant,
        const std::string & cache_dir) {
        const std::string & cache_dir,
        bool verbose) {
    std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
    std::string repo_part = sanitize_for_path(repo);

@@ -631,7 +670,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
        return nullptr;
    }

    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part, verbose);
    if (!model_opt.has_value()) {
        fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
        return nullptr;

@@ -659,8 +698,10 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
        return nullptr;
    }

    fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
            model.n_split, model.n_split - 1);
    if (verbose) {
        fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
                model.n_split, model.n_split - 1);
    }

    for (int i = 2; i <= model.n_split; i++) {
        char num_buf[6], total_buf[6];

@@ -668,7 +709,7 @@ gguf_context_ptr gguf_fetch_gguf_ctx(
        snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split);
        std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";

        auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
        auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part, verbose);
        if (!shard.has_value()) {
            fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
            return nullptr;

@@ -40,9 +40,11 @@ struct gguf_remote_model {
std::optional<gguf_remote_model> gguf_fetch_model_meta(
        const std::string & repo,
        const std::string & quant = "Q8_0",
        const std::string & cache_dir = ""); // empty = default
        const std::string & cache_dir = "", // empty = default
        bool verbose = true);

gguf_context_ptr gguf_fetch_gguf_ctx(
        const std::string & repo,
        const std::string & quant = "Q8_0",
        const std::string & cache_dir = "");
        const std::string & cache_dir = "",
        bool verbose = true);
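A short usage sketch for the defaulted arguments above; the header name and repo string are illustrative, not taken from the diff:

```cpp
#include "gguf-model-data.h" // assumed header name for these declarations

// Fetch remote model metadata with default quant/cache_dir, logging silenced.
void example() {
    auto meta = gguf_fetch_model_meta("ggml-org/models", "Q8_0", "", /*verbose=*/false);
    if (meta.has_value()) {
        // inspect meta->n_split and the parsed tensor metadata here
    }
}
```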
(12 file diffs suppressed because they are too large)
@@ -42,6 +42,7 @@
#include <thread>
#include <vector>
#include <unordered_map>
#include <map>

#ifdef __EMSCRIPTEN__
#    define N_THREADS 1

@@ -7621,6 +7622,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
        }
    }

    test_cases.emplace_back(new test_conv_2d( { 16, 16,   8, 1}, { 3, 3,   8, 12},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 16, 16,  16, 1}, { 3, 3,  16,  6},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 16, 16,  24, 1}, { 3, 3,  24,  6},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 16, 16,   8, 3}, { 3, 3,   8,  6},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 24, 24,  32, 1 }, { 3, 3,  32,  8},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 24, 24,  96, 1 }, { 3, 3,  96,  8},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128,  8},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
    test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 3 }, { 3, 3, 128,  8},
                GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));

    // sycl backend will limit task global_range < MAX_INT
    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
    // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)

@@ -8769,6 +8789,71 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
        }
    }

    // Stable-diffusion layers
    std::map<std::string, uint32_t> idx_sd{
        { "iw",   0 },
        { "ih",   1 },
        { "kw",   2 },
        { "kh",   3 },
        { "Cout", 4 },
        { "Cin",  5 },
        { "B",    6 },
    };

    // Input image size
    uint32_t w = 768;
    uint32_t h = 1024;

    // Number of filters (base)
    uint32_t Cout_b = 128;
    uint32_t Cin_b  = 128;

    std::vector<std::array<uint32_t, 7>> cases_sd = {
        { w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x10 (called 10 times)
        { w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x7
        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, // x5
        { w,     h,     3, 3, Cout_b,     Cin_b,     1 }, // x5
        { w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 }, // x4
        { w / 8, h / 8, 1, 1, 4,          4,         1 },
        { w / 8, h / 8, 3, 3, Cout_b * 4, 4,         1 },

        { w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 },
        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 },
        { w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 },

        { w,     h,     3, 3, Cout_b,     Cin_b * 2, 1 },
        { w,     h,     1, 1, Cout_b,     Cin_b * 2, 1 },
        { w,     h,     3, 3, Cout_b * 2, Cin_b * 2, 1 },

        { w,     h,     3, 3, 3,          Cin_b,     1 },
    };

    for (auto act_case : cases_sd) {
        GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
        GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);

        uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
        uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;

        test_cases.emplace_back(new test_conv_2d(
            { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
            { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
            GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
    }

    for (auto act_case : cases_sd) {
        GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
        GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);

        uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
        uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;

        test_cases.emplace_back(new test_conv_2d(
            { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
            { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
            GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false));
    }

    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
@ -589,6 +589,51 @@ static common_chat_tool amount_tool{
|
|||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool toggle_tool{
|
||||
/* .name = */ "toggle",
|
||||
/* .description = */ "Toggle a feature",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to enable the feature"
|
||||
}
|
||||
},
|
||||
"required": ["enabled"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool nullable_tool{
|
||||
/* .name = */ "set_nullable",
|
||||
/* .description = */ "Set a nullable value",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"type": "null",
|
||||
"description": "A null value"
|
||||
}
|
||||
},
|
||||
"required": ["value"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool config_tool{
|
||||
/* .name = */ "set_config",
|
||||
/* .description = */ "Set configuration",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "Configuration dict"
|
||||
}
|
||||
},
|
||||
"required": ["config"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool imaginary_number_tool{
|
||||
/* .name = */ "imaginary_number",
|
||||
/* .description = */ "Imaginary number converter",
|
||||
|
|
@ -612,6 +657,66 @@ static common_chat_tool imaginary_number_tool{
|
|||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool nullable_string_tool{
|
||||
/* .name = */ "set_nullable_str",
|
||||
/* .description = */ "Set a nullable string value",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["string", "null"],
|
||||
"description": "A nullable string"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool nullable_string_null_first_tool{
|
||||
/* .name = */ "set_nullable_str_nf",
|
||||
/* .description = */ "Set a nullable string value with null first in type array",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["null", "string"],
|
||||
"description": "A nullable string with null first"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool nullable_int_tool{
|
||||
/* .name = */ "set_nullable_int",
|
||||
/* .description = */ "Set a nullable integer value",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": ["integer", "null"],
|
||||
"description": "A nullable integer"
|
||||
}
|
||||
},
|
||||
"required": ["count"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool enum_no_type_tool{
|
||||
/* .name = */ "set_unit",
|
||||
/* .description = */ "Set a temperature unit",
|
||||
/* .parameters = */ R"({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"unit": {
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "Temperature unit"
|
||||
}
|
||||
},
|
||||
"required": ["unit"]
|
||||
})",
|
||||
};
|
||||
|
||||
static common_chat_tool string_param_tool{
|
||||
/* .name = */ "string_param",
|
||||
/* .description = */ "Tool with string parameter for testing",
|
||||
|
|
@ -1869,6 +1974,130 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).expect_reconstruction().run();
|
||||
}
|
||||
|
||||
{
|
||||
// Google Gemma 4 (tool calling with Gemma4 dict format)
|
||||
auto tst = peg_tester("models/templates/gemma4.jinja");
|
||||
|
||||
tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run();
|
||||
|
||||
// Simple tool call with string argument
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "London"})"))
|
||||
.run();
|
||||
|
||||
// Tool call with string argument containing special chars
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>San Francisco<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_tool_calls("get_time", R"({"city": "San Francisco"})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty args
|
||||
tst.test(
|
||||
"<|tool_call>call:empty_args{}<tool_call|>")
|
||||
.tools({ empty_args_tool })
|
||||
.expect(message_with_tool_calls("empty_args", "{}"))
|
||||
.run();
|
||||
|
||||
// Tool call with string and content
|
||||
tst.test(
|
||||
"Hello, world!\nWhat's up?<|tool_call>call:get_time{city:<|\"|>Paris<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool })
|
||||
.expect(message_with_content_and_tool_call("Hello, world!\nWhat's up?", "get_time", R"({"city": "Paris"})"))
|
||||
.run();
|
||||
|
||||
// Parallel tool calls
|
||||
tst.test(
|
||||
"<|tool_call>call:get_time{city:<|\"|>London<|\"|>}<tool_call|>"
|
||||
"<|tool_call>call:get_weather{city:<|\"|>Paris<|\"|>}<tool_call|>")
|
||||
.tools({ get_time_tool, get_weather_tool })
|
||||
.parallel_tool_calls(true)
|
||||
.expect_tool_calls({
|
||||
{ "get_time", R"({"city": "London"})", "" },
|
||||
{ "get_weather", R"({"city": "Paris"})", "" },
|
||||
})
|
||||
.run();
|
||||
|
||||
// Tool call with integer argument (number type)
|
||||
tst.test(
|
||||
"<|tool_call>call:special_function{arg1:42}<tool_call|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_with_tool_calls("special_function", R"({"arg1": 42})"))
|
||||
.run();
|
||||
|
||||
// Tool call with negative number argument
|
||||
tst.test(
|
||||
"<|tool_call>call:special_function{arg1:-7}<tool_call|>")
|
||||
.tools({ special_function_tool })
|
||||
.expect(message_with_tool_calls("special_function", R"({"arg1": -7})"))
|
||||
.run();
|
||||
|
||||
// Tool call with decimal number argument
|
||||
tst.test(
|
||||
"<|tool_call>call:amount{orig:3.14}<tool_call|>")
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 3.14})"))
|
||||
.run();
|
||||
|
||||
// Tool call with boolean argument (true)
|
||||
tst.test(
|
||||
"<|tool_call>call:toggle{enabled:true}<tool_call|>")
|
||||
.tools({ toggle_tool })
|
||||
.expect(message_with_tool_calls("toggle", R"({"enabled": true})"))
|
||||
.run();
|
||||
|
||||
// Tool call with boolean argument (false)
|
||||
tst.test(
|
||||
"<|tool_call>call:toggle{enabled:false}<tool_call|>")
|
||||
.tools({ toggle_tool })
|
||||
.expect(message_with_tool_calls("toggle", R"({"enabled": false})"))
|
||||
.run();
|
||||
|
||||
// Tool call with null argument
|
||||
tst.test(
|
||||
"<|tool_call>call:set_nullable{value:null}<tool_call|>")
|
||||
.tools({ nullable_tool })
|
||||
.expect(message_with_tool_calls("set_nullable", R"({"value": null})"))
|
||||
.run();
|
||||
|
||||
// Tool call with array argument (todo list)
|
||||
tst.test(
|
||||
"<|tool_call>call:todo_list{todos:[<|\"|>buy milk<|\"|>,<|\"|>walk dog<|\"|>]}<tool_call|>")
|
||||
.tools({ todo_list })
|
||||
.expect(message_with_tool_calls("todo_list", R"({"todos":["buy milk","walk dog"]})"))
|
||||
.run();
|
||||
|
||||
// Tool call with object/dict argument
|
||||
tst.test(
|
||||
"<|tool_call>call:set_config{config:{theme:<|\"|>dark<|\"|>,count:3}}<tool_call|>")
|
||||
.tools({ config_tool })
|
||||
.expect(message_with_tool_calls("set_config", R"({"config":{"theme":"dark","count":3}})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty array
|
||||
tst.test(
|
||||
"<|tool_call>call:todo_list{todos:[]}<tool_call|>")
|
||||
.tools({ todo_list })
|
||||
.expect(message_with_tool_calls("todo_list", R"({"todos":[]})"))
|
||||
.run();
|
||||
|
||||
// Tool call with empty dict
|
||||
tst.test(
|
||||
"<|tool_call>call:set_config{config:{}}<tool_call|>")
|
||||
.tools({ config_tool })
|
||||
.expect(message_with_tool_calls("set_config", R"({"config":{}})"))
|
||||
.run();
|
||||
|
||||
// Tool call with scientific notation number
|
||||
tst.test(
|
||||
"<|tool_call>call:amount{orig:1.5e10}<tool_call|>")
|
||||
.tools({ amount_tool })
|
||||
.expect(message_with_tool_calls("amount", R"({"orig": 1.5e10})"))
|
||||
.run();
|
||||
}
|
||||
|
||||
{
|
||||
// Qwen-QwQ-32B (reasoning model)
|
||||
auto tst = peg_tester("models/templates/Qwen-QwQ-32B.jinja");
|
||||
|
|
@ -2031,6 +2260,7 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
}
|
||||
})
|
||||
.run();
|
||||
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2214,6 +2444,58 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
|||
})
|
||||
.expect_reconstruction()
|
||||
.run();
|
||||
|
||||
// nullable string type ["string", "null"]
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=set_nullable_str>\n"
|
||||
"<parameter=name>\nhello world\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.tools({ nullable_string_tool })
|
||||
.expect_tool_calls({
|
||||
{ "set_nullable_str", R"({"name": "hello world"})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// nullable string with null first in type array ["null", "string"]
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=set_nullable_str_nf>\n"
|
||||
"<parameter=name>\nhello world\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.tools({ nullable_string_null_first_tool })
|
||||
.expect_tool_calls({
|
||||
{ "set_nullable_str_nf", R"({"name": "hello world"})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// nullable integer type ["integer", "null"] - should use JSON value path, not string
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=set_nullable_int>\n"
|
||||
"<parameter=count>\n42\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.tools({ nullable_int_tool })
|
||||
.expect_tool_calls({
|
||||
{ "set_nullable_int", R"({"count": 42})", {} },
|
||||
})
|
||||
.run();
|
||||
|
||||
// enum without explicit type key - should infer string from enum values
|
||||
tst.test(
|
||||
"<tool_call>\n"
|
||||
"<function=set_unit>\n"
|
||||
"<parameter=unit>\ncelsius\n</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>")
|
||||
.tools({ enum_no_type_tool })
|
||||
.expect_tool_calls({
|
||||
{ "set_unit", R"({"unit": "celsius"})", {} },
|
||||
})
|
||||
.run();
|
||||
}
|
||||
{
|
||||
auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug);
|
||||
|
|
@@ -2372,55 +2654,57 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
    // #20424 introduced effective_input = generation_prompt + input, but the throw
    // uses input.substr(result.end) where result.end is in effective_input space.
    {
        if (!g_template_filter.empty() && std::string("models/templates/GLM-4.7-Flash.jinja").find(g_template_filter) != std::string::npos) {
            auto tmpls = common_chat_templates_ptr(
                common_chat_templates_init(nullptr, read_file("models/templates/GLM-4.7-Flash.jinja")));

            static common_chat_tool weather_tool{
                "get_weather", "Get weather",
                R"({"type":"object","properties":{"city":{"type":"string"}},"required":["city"]})",
            };

            common_chat_templates_inputs inputs;
            inputs.tools = { weather_tool };
            inputs.enable_thinking = true;
            inputs.reasoning_format = COMMON_REASONING_FORMAT_AUTO;
            inputs.add_generation_prompt = true;
            inputs.use_jinja = true;
            common_chat_msg msg;
            msg.role = "user";
            msg.content = "get_weather";
            inputs.messages = { msg };

            auto params = common_chat_templates_apply(tmpls.get(), inputs);
            common_peg_arena arena;
            arena.load(params.parser);
            common_chat_parser_params pp(params);

            // generation_prompt is non-empty for thinking models, so result.end
            // will be offset by generation_prompt.size() into effective_input space.
            assert(!pp.generation_prompt.empty());

            std::string bad_input =
                "Thinking.\n"
                "</think>"
                "<tool_call>get_weather"
                "<arg_key>city</arg_key><arg_value>Tokyo</arg_value>"
                "</tool_call>\n";

            bool got_runtime_error = false;
            bool got_out_of_range = false;
            std::string error_msg;
            try {
                common_chat_peg_parse(arena, bad_input, /*is_partial=*/false, pp);
            } catch (const std::out_of_range & e) {
                got_out_of_range = true;
                error_msg = e.what();
            } catch (const std::runtime_error & e) {
                got_runtime_error = true;
                error_msg = e.what();
            }
            GGML_ASSERT(!got_out_of_range && "throw path crashed with out_of_range (input.substr in effective_input space)");
            GGML_ASSERT(got_runtime_error && "throw path should produce std::runtime_error with parse position");
        }
    }
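To spell out the failure mode this block pins down: parsing runs over effective_input = generation_prompt + input, so an error offset such as result.end lives in the combined space, and slicing the caller's input with it directly can read past the end and raise std::out_of_range. A hedged sketch of the translation the error path needs (names follow the test above; this is not the literal fix as merged):

    // Sketch: map result.end from effective_input space back into input space
    // before slicing, so the error message quotes the right spot.
    const size_t off = pp.generation_prompt.size();
    const size_t pos = result.end >= off ? result.end - off : 0;
    throw std::runtime_error("parse failed at byte " + std::to_string(pos) + ": " + input.substr(pos));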

    // Kimi-K2-Thinking tests - custom parser

@@ -3000,6 +3284,21 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
            .expect(message_assist_call_id)
            .expect_reconstruction()
            .run();

        tst.test("[TOOL_CALLS]special_function[CALL_ID]000000001[ARGS]{\"arg1\": 1}"
                 "[TOOL_CALLS]special_function_with_opt[CALL_ID]000000002[ARGS]{\"arg1\": 1, \"arg2\": 2}")
            .parallel_tool_calls(true)
            .tools({
                special_function_tool, special_function_tool_with_optional_param
            })
            .expect_tool_calls({
                { "special_function", R"({"arg1": 1})", "000000001" },
                { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", "000000002" },
            })
            .expect_reconstruction()
            .run();
    }

    // Devstral
    {

@@ -116,6 +116,39 @@ int main() {
    // Verify tensor count
    TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780");

    // Test a hybrid-attention model with array-valued head counts
    auto result4 = gguf_fetch_model_meta("ggml-org/Step-3.5-Flash-GGUF", "Q4_K");
    if (!result4.has_value()) {
        fprintf(stderr, "FAIL: could not fetch Step-3.5-Flash metadata\n");
        return 1;
    }
    const auto & model4 = result4.value();

    fprintf(stderr, "Architecture: %s\n", model4.architecture.c_str());
    fprintf(stderr, "n_embd: %u\n", model4.n_embd);
    fprintf(stderr, "n_ff: %u\n", model4.n_ff);
    fprintf(stderr, "n_vocab: %u\n", model4.n_vocab);
    fprintf(stderr, "n_layer: %u\n", model4.n_layer);
    fprintf(stderr, "n_head: %u\n", model4.n_head);
    fprintf(stderr, "n_head_kv: %u\n", model4.n_head_kv);
    fprintf(stderr, "n_expert: %u\n", model4.n_expert);
    fprintf(stderr, "n_embd_head_k: %u\n", model4.n_embd_head_k);
    fprintf(stderr, "n_embd_head_v: %u\n", model4.n_embd_head_v);
    fprintf(stderr, "tensors: %zu\n", model4.tensors.size());

    TEST_ASSERT(model4.architecture == "step35", "expected architecture 'step35'");

    TEST_ASSERT(model4.n_layer == 45, "expected n_layer == 45");
    TEST_ASSERT(model4.n_embd == 4096, "expected n_embd == 4096");
    TEST_ASSERT(model4.n_ff == 11264, "expected n_ff == 11264");
    TEST_ASSERT(model4.n_head == 64, "expected n_head == 64 (first element of per-layer array)");
    TEST_ASSERT(model4.n_head_kv == 8, "expected n_head_kv == 8 (first element of per-layer array)");
    TEST_ASSERT(model4.n_expert == 288, "expected n_expert == 288");
    TEST_ASSERT(model4.n_embd_head_k == 128, "expected n_embd_head_k == 128");
    TEST_ASSERT(model4.n_embd_head_v == 128, "expected n_embd_head_v == 128");
    TEST_ASSERT(model4.n_vocab == 128896, "expected n_vocab == 128896");
    TEST_ASSERT(model4.tensors.size() == 754, "expected tensor count == 754");

    fprintf(stderr, "=== ALL TESTS PASSED ===\n");
    return 0;
}
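The point of the Step-3.5-Flash case: hybrid-attention models store head counts as per-layer arrays rather than scalars, and the expectations above assume the fetcher reports the first element. A sketch of how a GGUF key that may be scalar or array can be read with the public gguf API (the key name and u32 element type are assumptions for illustration):

    // Sketch: read a possibly array-valued GGUF key, taking the first element.
    static uint32_t read_first_u32(const gguf_context * ctx, const char * key) {
        const int64_t i = gguf_find_key(ctx, key); // e.g. "step35.attention.head_count" (assumed)
        if (i < 0) {
            return 0;
        }
        if (gguf_get_kv_type(ctx, i) == GGUF_TYPE_ARRAY) {
            // per-layer array: the tests above expect the first element (64 for n_head)
            return ((const uint32_t *) gguf_get_arr_data(ctx, i))[0];
        }
        return gguf_get_val_u32(ctx, i);
    }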

@@ -523,6 +523,18 @@ static void test_filters(testing & t) {
        "hello"
    );

    test_template(t, "upper array",
        "{{ items|upper }}",
        {{"items", json::array({"hello", "world"})}},
        "['HELLO', 'WORLD']"
    );

    test_template(t, "upper dict",
        "{{ items|upper }}",
        {{"items", {{"hello", "world"}}}},
        "{'HELLO': 'WORLD'}"
    );

    test_template(t, "capitalize",
        "{{ 'heLlo World'|capitalize }}",
        json::object(),
@@ -0,0 +1,520 @@
#include "../src/llama-ext.h"
#include "ggml-cpp.h"
#include "gguf-model-data.h"
#include "llama.h"

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// ---------------------------------------------------------------------------
// ftype name <-> enum mapping
// ---------------------------------------------------------------------------

struct ftype_name_entry {
    const char * name;
    llama_ftype ftype;
};

static const ftype_name_entry ftype_name_table[] = {
    { "F32", LLAMA_FTYPE_ALL_F32 },
    { "F16", LLAMA_FTYPE_MOSTLY_F16 },
    { "BF16", LLAMA_FTYPE_MOSTLY_BF16 },
    { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0 },
    { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1 },
    { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0 },
    { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1 },
    { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0 },
    { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K },
    { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S },
    { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S },
    { "Q3_K_M", LLAMA_FTYPE_MOSTLY_Q3_K_M },
    { "Q3_K_L", LLAMA_FTYPE_MOSTLY_Q3_K_L },
    { "Q4_K_S", LLAMA_FTYPE_MOSTLY_Q4_K_S },
    { "Q4_K_M", LLAMA_FTYPE_MOSTLY_Q4_K_M },
    { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S },
    { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M },
    { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K },
    { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S },
    { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M },
    { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS },
    { "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS },
    { "IQ2_S", LLAMA_FTYPE_MOSTLY_IQ2_S },
    { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M },
    { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS },
    { "IQ3_XS", LLAMA_FTYPE_MOSTLY_IQ3_XS },
    { "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S },
    { "IQ3_M", LLAMA_FTYPE_MOSTLY_IQ3_M },
    { "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL },
    { "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS },
    { "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0 },
    { "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0 },
    { "MXFP4_MOE", LLAMA_FTYPE_MOSTLY_MXFP4_MOE },
    { "NVFP4", LLAMA_FTYPE_MOSTLY_NVFP4 },
};

static llama_ftype llama_ftype_from_name(const char * name) {
    for (const auto & e : ftype_name_table) {
        if (strcmp(name, e.name) == 0) {
            return e.ftype;
        }
    }
    return (llama_ftype) -1;
}

static const char * llama_ftype_to_name(llama_ftype ftype) {
    for (const auto & e : ftype_name_table) {
        if (e.ftype == ftype) {
            return e.name;
        }
    }
    return nullptr;
}

// ---------------------------------------------------------------------------
// ggml_type name lookup
// ---------------------------------------------------------------------------

static ggml_type ggml_type_from_name(const std::string & name) {
    for (int i = 0; i < GGML_TYPE_COUNT; i++) {
        const char * tname = ggml_type_name((ggml_type) i);
        if (tname && name == tname) {
            return (ggml_type) i;
        }
    }
    return GGML_TYPE_COUNT;
}

// ---------------------------------------------------------------------------
// File parser for snapshot files (quant type schemas)
// ---------------------------------------------------------------------------

struct snapshot_section {
    llama_ftype ftype;
    ggml_type default_type;
    std::vector<std::pair<std::string, ggml_type>> overrides;
};

// This function is pretty ugly, but it's a trade-off of readable snapshot files
// versus readable parsing code
static bool parse_snapshot_file(const std::string & path, std::vector<snapshot_section> & sections) {
    std::ifstream f(path);
    if (!f.good()) {
        return false;
    }

    snapshot_section * cur = nullptr;
    std::string line;

    while (std::getline(f, line)) {
        if (line.empty() || line[0] == '#') {
            continue;
        }

        // section header: [FTYPE_NAME] default_type
        if (line[0] == '[') {
            auto close = line.find(']');
            if (close == std::string::npos) {
                fprintf(stderr, "parse error: missing ] in '%s'\n", line.c_str());
                return false;
            }
            std::string ftype_str = line.substr(1, close - 1);
            std::string default_str;
            size_t pos = close + 1;
            while (pos < line.size() && line[pos] == ' ') {
                pos++;
            }
            default_str = line.substr(pos);

            llama_ftype ftype = llama_ftype_from_name(ftype_str.c_str());
            if ((int) ftype < 0) {
                fprintf(stderr, "parse error: unknown ftype '%s'\n", ftype_str.c_str());
                return false;
            }

            ggml_type dtype = ggml_type_from_name(default_str);
            if (dtype == GGML_TYPE_COUNT) {
                fprintf(stderr, "parse error: unknown default type '%s'\n", default_str.c_str());
                return false;
            }

            sections.push_back({ ftype, dtype, {} });
            cur = &sections.back();
            continue;
        }

        if (!cur) {
            fprintf(stderr, "parse error: tensor line before any section: '%s'\n", line.c_str());
            return false;
        }

        auto sp = line.rfind(' ');
        if (sp == std::string::npos) {
            fprintf(stderr, "parse error: no space in tensor line: '%s'\n", line.c_str());
            return false;
        }

        std::string tname = line.substr(0, sp);
        std::string ttype = line.substr(sp + 1);

        ggml_type gt = ggml_type_from_name(ttype);
        if (gt == GGML_TYPE_COUNT) {
            fprintf(stderr, "parse error: unknown type '%s' for tensor '%s'\n", ttype.c_str(), tname.c_str());
            return false;
        }

        cur->overrides.push_back({ tname, gt });
    }

    return true;
}
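To make the accepted format concrete, here is a hypothetical snapshot excerpt in the shape parse_snapshot_file() expects: comment lines, then a "[FTYPE] default_type" section header per ftype, then one "tensor_name type" override per line (tensor names and types below are illustrative, not from a real schema file):

    # Model: Example
    [Q4_K_M] q4_K
    blk.0.attn_v.weight q6_K
    output.weight q6_K

    [Q8_0] q8_0
    output.weight f16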

// ---------------------------------------------------------------------------
// Remote model support using gguf-model-data.cpp
// ---------------------------------------------------------------------------

struct remote_model_spec {
    const char * repo;
    const char * quant;
};

// Get model name from repo: strip org prefix, strip -GGUF suffix,
// and strip anything up to and including first '_' (e.g. "deepseek-ai_DeepSeek-V3.1").
static std::string model_name_from_repo(const char * repo) {
    std::string s(repo);

    auto slash = s.find('/');
    if (slash != std::string::npos) {
        s = s.substr(slash + 1);
    }

    const std::string suffix = "-GGUF";
    if (s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0) {
        s = s.substr(0, s.size() - suffix.size());
    }

    auto underscore = s.find('_');
    if (underscore != std::string::npos) {
        s = s.substr(underscore + 1);
    }

    return s;
}
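// For example (derived from the rules above):
//   "ggml-org/Qwen3-0.6B-GGUF"                 -> "Qwen3-0.6B"
//   "bartowski/deepseek-ai_DeepSeek-V3.1-GGUF" -> "DeepSeek-V3.1"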

static std::string snapshot_file_from_name(const std::string & name) {
    std::string lower = name;
    for (auto & c : lower) {
        c = std::tolower(c);
    }
    return lower;
}

static const remote_model_spec model_specs[] = {
    { "ggml-org/Qwen3-0.6B-GGUF", "Q8_0" },
    { "ggml-org/GLM-4.6V-GGUF", "Q8_0" },
    { "ggml-org/Step-3.5-Flash-GGUF", "Q4_K" },
    { "ggml-org/Qwen3-Coder-Next-GGUF", "Q8_0" },
    { "ggml-org/Qwen3-14B-GGUF", "Q8_0" },
    { "ggml-org/Nemotron-Nano-3-30B-A3B-GGUF", "Q8_0" },
    { "ggml-org/gpt-oss-120b-GGUF", "mxfp4" },
    { "ggml-org/gemma-3-4b-it-GGUF", "Q8_0" },
    { "bartowski/Meta-Llama-3.1-70B-Instruct-GGUF", "Q4_K_M" },
    { "bartowski/deepseek-ai_DeepSeek-V3.1-GGUF", "IQ1_M" },
    { "bartowski/Qwen_Qwen3.5-397B-A17B-GGUF", "IQ1_S" }, // TODO: swap with ggml-org if/when it's released
    { "bartowski/Qwen_Qwen3.5-27B-GGUF", "Q8_0" },        // TODO: swap with ggml-org if/when it's released
};

static const int n_model_specs = (int) (sizeof(model_specs) / sizeof(model_specs[0]));

static llama_model * build_mock_model_from_remote(const gguf_remote_model & remote) {
    llama_quant_model_desc desc = {};
    desc.architecture = remote.architecture.c_str();
    desc.n_embd = remote.n_embd;
    desc.n_ff = remote.n_ff;
    desc.n_layer = remote.n_layer;
    desc.n_head = remote.n_head;
    desc.n_head_kv = remote.n_head_kv;
    desc.n_expert = remote.n_expert;
    desc.n_embd_head_k = remote.n_embd_head_k;
    desc.n_embd_head_v = remote.n_embd_head_v;
    return llama_quant_model_from_metadata(&desc);
}

// Single ggml context holding all quantizable tensors for a model.
struct mock_tensors {
    ggml_context_ptr ctx;
    std::vector<ggml_tensor *> tensors;
};

static mock_tensors build_mock_tensors(const quantize_state_impl * qs, const gguf_remote_model & remote) {
    const size_t ctx_size = remote.tensors.size() * ggml_tensor_overhead();
    struct ggml_init_params params = { ctx_size, nullptr, true };
    ggml_context_ptr ctx(ggml_init(params));

    std::vector<ggml_tensor *> result;

    for (const auto & t : remote.tensors) {
        ggml_tensor * gt = ggml_new_tensor_4d(ctx.get(), GGML_TYPE_F32, t.ne[0], t.ne[1], t.ne[2], t.ne[3]);
        ggml_set_name(gt, t.name.c_str());
        if (llama_quant_tensor_allows_quantization(qs, gt)) {
            result.push_back(gt);
        }
    }

    // sort by layer index then name, matching llama_model_loader::weight_name_comparer
    std::sort(result.begin(), result.end(), [](const ggml_tensor * a, const ggml_tensor * b) {
        int a_layer = -1, b_layer = -1;
        sscanf(a->name, "blk.%d.", &a_layer);
        sscanf(b->name, "blk.%d.", &b_layer);
        if (a_layer != b_layer) {
            return a_layer < b_layer;
        }
        return strcmp(a->name, b->name) < 0;
    });

    return { std::move(ctx), std::move(result) };
}
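// Resulting order (illustrative): names that do not match "blk.%d." keep layer -1
// and therefore sort first (e.g. output.weight, token_embd.weight), followed by
// blk.0.*, blk.1.*, ... alphabetically within each layer.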

// ---------------------------------------------------------------------------
// Generate mode: regenerate all snapshot files
// Use this when either adding new models or modifying quants
// ---------------------------------------------------------------------------

static std::string generate_snapshot(const std::string & name,
                                     const gguf_remote_model & remote,
                                     quantize_state_impl * qs,
                                     mock_tensors & mt) {
    std::ostringstream out;

    out << "# Model: " << name << "\n";
    out << "# n_embd=" << remote.n_embd << ", n_ff=" << remote.n_ff << ", n_vocab=" << remote.n_vocab
        << ", n_layer=" << remote.n_layer << ", n_head=" << remote.n_head << ", n_head_kv=" << remote.n_head_kv;
    if (remote.n_expert > 0) {
        out << ", n_expert=" << remote.n_expert;
    }
    out << "\n";

    for (int i = 0; i < LLAMA_FTYPE_GUESSED; i++) {
        llama_ftype ft = (llama_ftype) i;
        ggml_type default_type = llama_ftype_get_default_type(ft);
        if (default_type == GGML_TYPE_COUNT) {
            continue;
        }
        const char * fname = llama_ftype_to_name(ft);
        if (!fname) {
            continue;
        }

        std::vector<ggml_type> result_types(mt.tensors.size());
        llama_quant_compute_types(qs, ft, mt.tensors.data(), result_types.data(), mt.tensors.size());

        out << "\n[" << fname << "] " << ggml_type_name(default_type) << "\n";
        for (size_t j = 0; j < mt.tensors.size(); j++) {
            if (result_types[j] != default_type) {
                out << ggml_get_name(mt.tensors[j]) << " " << ggml_type_name(result_types[j]) << "\n";
            }
        }
    }

    return out.str();
}

static int run_generate(const std::string & snapshot_dir) {
    fprintf(stderr, "This will overwrite all snapshot files in:\n  %s\n", snapshot_dir.c_str());
    fprintf(stderr, "Continue? [y/N] ");
    int ch = fgetc(stdin);
    if (ch != 'y' && ch != 'Y') {
        fprintf(stderr, "Aborted.\n");
        return 1;
    }

    fprintf(stderr, "\n");

    int n_written = 0;

    for (int m = 0; m < n_model_specs; m++) {
        const auto & spec = model_specs[m];
        std::string name = model_name_from_repo(spec.repo);

        fprintf(stderr, "Fetching model metadata for %s from %s...\n", name.c_str(), spec.repo);
        auto result = gguf_fetch_model_meta(spec.repo, spec.quant);
        if (!result.has_value()) {
            fprintf(stderr, "ERROR: could not fetch model metadata for %s\n", name.c_str());
            return 1;
        }

        const auto & remote = result.value();
        llama_model * model = build_mock_model_from_remote(remote);
        llama_model_quantize_params qparams = llama_model_quantize_default_params();
        quantize_state_impl * qs = llama_quant_init(model, &qparams);
        auto mt = build_mock_tensors(qs, remote);

        std::string content = generate_snapshot(name, remote, qs, mt);
        std::string path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";

        std::ofstream f(path);
        if (!f.good()) {
            fprintf(stderr, "ERROR: could not write %s\n", path.c_str());
            llama_quant_free(qs);
            llama_model_free(model);
            return 1;
        }
        f << content;
        n_written++;
        fprintf(stderr, "  wrote %s\n", path.c_str());
        llama_quant_free(qs);
        llama_model_free(model);
    }

    fprintf(stderr, "%d files written\n", n_written);
    return 0;
}

// ---------------------------------------------------------------------------
// Test mode: compare against snapshot files
// ---------------------------------------------------------------------------

static bool run_test_section(quantize_state_impl * qs, mock_tensors & mt, const snapshot_section & section) {
    // verify default_type matches what llama_ftype_get_default_type returns
    ggml_type computed_default = llama_ftype_get_default_type(section.ftype);
    if (computed_default != section.default_type) {
        printf("  FAIL [%s] default type mismatch: file says %s, code says %s\n", llama_ftype_to_name(section.ftype),
               ggml_type_name(section.default_type), ggml_type_name(computed_default));
        return false;
    }

    std::vector<ggml_type> result_types(mt.tensors.size());
    llama_quant_compute_types(qs, section.ftype, mt.tensors.data(), result_types.data(), mt.tensors.size());

    std::map<std::string, ggml_type> override_map(section.overrides.begin(), section.overrides.end());

    bool all_pass = true;
    int n_override_found = 0;

    for (size_t i = 0; i < mt.tensors.size(); i++) {
        const char * name = ggml_get_name(mt.tensors[i]);
        ggml_type got = result_types[i];

        ggml_type expected = section.default_type;
        auto it = override_map.find(name);
        if (it != override_map.end()) {
            expected = it->second;
            n_override_found++;
        }

        if (got != expected) {
            printf("  FAIL %-50s %-10s expected %s, got %s\n", name, llama_ftype_to_name(section.ftype),
                   ggml_type_name(expected), ggml_type_name(got));
            all_pass = false;
        }
    }

    if (n_override_found != (int) section.overrides.size()) {
        printf("  FAIL [%s] override count mismatch: listed %d, matched %d\n", llama_ftype_to_name(section.ftype),
               (int) section.overrides.size(), n_override_found);
        all_pass = false;
    }

    return all_pass;
}
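// Note: a section passes only if every tensor resolves to the section's default
// type unless an override names it, and every listed override matched a tensor
// that actually exists in the mock model.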

static int run_remote_tests(const std::string & snapshot_dir, const char * argv0) {
    int total_pass = 0;
    int total_fail = 0;
    int total_skip = 0;

    for (int m = 0; m < n_model_specs; m++) {
        const auto & spec = model_specs[m];
        std::string name = model_name_from_repo(spec.repo);
        printf("=== %s ===\n", name.c_str());

        auto result = gguf_fetch_model_meta(spec.repo, spec.quant, "", false);
        if (!result.has_value()) {
            printf("  SKIP (could not fetch model metadata)\n\n");
            total_skip++;
            continue;
        }

        const auto & remote = result.value();
        llama_model * model = build_mock_model_from_remote(remote);
        llama_model_quantize_params qparams = llama_model_quantize_default_params();
        quantize_state_impl * qs = llama_quant_init(model, &qparams);
        auto mt = build_mock_tensors(qs, remote);

        std::string snapshot_path = snapshot_dir + "/" + snapshot_file_from_name(name) + ".schema";
        std::vector<snapshot_section> sections;
        if (!parse_snapshot_file(snapshot_path, sections)) {
            printf("  SKIP (could not read snapshot file: %s)\n\n", snapshot_path.c_str());
            llama_quant_free(qs);
            llama_model_free(model);
            total_skip++;
            continue;
        }

        int model_pass = 0;
        int model_fail = 0;

        for (const auto & section : sections) {
            bool pass = run_test_section(qs, mt, section);
            if (pass) {
                model_pass++;
            } else {
                model_fail++;
            }
        }

        printf("  %s %s: %d/%d ftype sections passed (%d tensors)\n", model_fail == 0 ? "PASS" : "FAIL", name.c_str(),
               model_pass, model_pass + model_fail, (int) mt.tensors.size());
        printf("\n");

        if (model_fail == 0) {
            total_pass++;
        } else {
            total_fail++;
        }

        llama_quant_free(qs);
        llama_model_free(model);
    }

    printf("%d/%d models passed", total_pass, total_pass + total_fail);
    if (total_skip > 0) {
        printf(", %d skipped", total_skip);
    }
    printf("\n");

    if (total_fail > 0) {
        printf("\nIf these changes are intentional, regenerate snapshot files with:\n");
        printf("  %s --generate\n", argv0);
    }

    return total_fail > 0 ? 1 : 0;
}

int main(int argc, char ** argv) {
    std::string snapshot_dir = SNAPSHOT_DIR;
    bool generate = false;

    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "--generate") == 0) {
            generate = true;
        } else if (strcmp(argv[i], "--snapshot-dir") == 0 && i + 1 < argc) {
            snapshot_dir = argv[++i];
        }
    }

    if (generate) {
        return run_generate(snapshot_dir);
    }

    // suppress llama log warnings during test (e.g. tensor type fallback messages)
    llama_log_set([](enum ggml_log_level, const char *, void *) {}, nullptr);

    return run_remote_tests(snapshot_dir, argv[0]);
}

@@ -5,15 +5,15 @@
#include "gguf.h"
#include "jinja/runtime.h"
#include "log.h"

#include <fstream>
#include <numeric>
#include <optional>
#include <sstream>
#include <string>

#include "nlohmann/json.hpp"
#include "peg-parser.h"

using json = nlohmann::ordered_json;

enum class output_mode {
@@ -34,14 +34,14 @@ enum class input_message_type {
};

struct debug_options {
    std::string template_path;
    bool with_tools = true;
    bool generation_prompt = true;
    bool enable_reasoning = true;
    bool debug_jinja = false;
    bool force_tool_call = false;
    output_mode mode = output_mode::BOTH;
    input_message_type input_message = input_message_type::NONE;
};

static std::string read_file(const std::string & path) {
@@ -274,7 +274,7 @@ static void render_scenario(const common_chat_template & tmpl,
    json final_messages = messages;
    if (add_generation_prompt && !messages.empty() && messages.back().value("role", "") == "assistant") {
        final_messages.push_back(json{
            { "role", "user" },
            { "content", "Now please continue with another response." }
        });
    }

@@ -305,7 +305,7 @@ static void render_all_scenarios(const common_chat_template & tmpl,
                                 const json & tools,
                                 bool add_generation_prompt,
                                 bool enable_thinking,
                                 input_message_type message_type) {
    json user_msg = build_user_message();

    auto render_if = [&](input_message_type type, const std::string & name, const json & assistant_msg) {
@@ -335,6 +335,24 @@ static void render_all_scenarios(const common_chat_template & tmpl,
    }
}

static autoparser::generation_params prepare_params(const debug_options & opts, const json & tools) {
    autoparser::generation_params params;
    params.messages = json::array({ build_user_message() });
    params.reasoning_format = opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE;
    params.enable_thinking = opts.enable_reasoning;
    params.add_generation_prompt = opts.generation_prompt;

    if (opts.with_tools) {
        params.tools = tools;
        params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO;
    } else {
        params.tools = json();
        params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
    }
    params.parallel_tool_calls = false;
    return params;
}

int main(int argc, char ** argv) {
    // Set log level to most verbose to capture all debug output
    common_log_set_verbosity_thold(99);

@@ -369,49 +387,41 @@ int main(int argc, char ** argv) {
    try {
        common_chat_template chat_template(template_source, "", "");

        // Build tools definition
        json tools = opts.with_tools ? build_tools_definition() : json();

        autoparser::generation_params params = prepare_params(opts, tools);
        common_chat_params parser_data;
        if (std::optional<common_chat_params> spec_tmpl =
                common_chat_try_specialized_template(chat_template, template_source, params)) {
            LOG_ERR("\n");
            LOG_ERR("================================================================================\n");
            LOG_ERR(" TEMPLATE RENDERING OUTPUT\n");
            LOG_ERR("================================================================================\n");
            LOG_ERR("This template uses a specialized parser, analysis results will not be available.");
            parser_data = *spec_tmpl;
        } else {
            // Render template scenarios if requested
            if (opts.input_message != input_message_type::NONE &&
                (opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) {
                LOG_ERR("\n");
                LOG_ERR("================================================================================\n");
                LOG_ERR(" TEMPLATE RENDERING OUTPUT\n");
                LOG_ERR("================================================================================\n");

                render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning,
                                     opts.input_message);
            }

            // Output analysis if requested
            if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) {
                LOG_ERR("\n");
                LOG_ERR("================================================================================\n");
                LOG_ERR(" TEMPLATE ANALYSIS\n");
                LOG_ERR("================================================================================\n");

                autoparser::autoparser analysis;
                analysis.analyze_template(chat_template);

                // Generate Parser
                parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis);
            }
        }

        LOG_ERR("\n=== Generated Parser ===\n");
        common_peg_arena arena;