diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile
index 9ce80a71eb..cd2f9aa79b 100644
--- a/.devops/intel.Dockerfile
+++ b/.devops/intel.Dockerfile
@@ -1,8 +1,8 @@
-ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04
+ARG ONEAPI_VERSION=2025.2.2-0-devel-ubuntu24.04
## Build Image
-FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build
+FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build
ARG GGML_SYCL_F16=OFF
RUN apt-get update && \
@@ -31,7 +31,7 @@ RUN mkdir -p /app/full \
&& cp requirements.txt /app/full \
&& cp .devops/tools.sh /app/full/tools.sh
-FROM intel/oneapi-basekit:$ONEAPI_VERSION AS base
+FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS base
RUN apt-get update \
&& apt-get install -y libgomp1 curl\
diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile
index 106c62b4dc..df9058d946 100644
--- a/.devops/rocm.Dockerfile
+++ b/.devops/rocm.Dockerfile
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04
# This needs to generally match the container host's environment.
-ARG ROCM_VERSION=6.4
-ARG AMDGPU_VERSION=6.4
+ARG ROCM_VERSION=7.0
+ARG AMDGPU_VERSION=7.0
# Target the ROCm build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
@@ -13,9 +13,8 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
# Unless otherwise specified, we make a fat build.
# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878
# This is mostly tied to rocBLAS supported archs.
-# gfx803, gfx900, gfx1032, gfx1101, gfx1102,not officialy supported
-# gfx906 is deprecated
-#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
+# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102, not officially supported
+# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151'
#ARG ROCM_DOCKER_ARCH='gfx1151'
@@ -36,13 +35,10 @@ WORKDIR /app
COPY . .
-RUN git clone https://github.com/rocm/rocwmma --branch develop --depth 1
-
RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \
cmake -S . -B build \
-DGGML_HIP=ON \
-DGGML_HIP_ROCWMMA_FATTN=ON \
- -DCMAKE_HIP_FLAGS="-I$(pwd)/rocwmma/library/include/" \
-DAMDGPU_TARGETS="$ROCM_DOCKER_ARCH" \
-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON \
-DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 410552813a..db885907cd 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -97,7 +97,7 @@ jobs:
ctest -L 'main|curl' --verbose --timeout 900
macOS-latest-cmake-x64:
- runs-on: macos-13
+ runs-on: macos-15-intel
steps:
- name: Clone
@@ -362,11 +362,11 @@ jobs:
id: checkout
uses: actions/checkout@v4
- - name: ccache
- uses: ggml-org/ccache-action@v1.2.16
- with:
- key: ubuntu-latest-cmake-rpc
- evict-old-files: 1d
+ # - name: ccache
+ # uses: ggml-org/ccache-action@v1.2.16
+ # with:
+ # key: ubuntu-latest-cmake-rpc
+ # evict-old-files: 1d
- name: Dependencies
id: depends
@@ -387,8 +387,8 @@ jobs:
cd build
ctest -L main --verbose
- ubuntu-22-cmake-vulkan:
- runs-on: ubuntu-22.04
+ ubuntu-24-cmake-vulkan:
+ runs-on: ubuntu-24.04
steps:
- name: Clone
@@ -398,20 +398,40 @@ jobs:
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
with:
- key: ubuntu-22-cmake-vulkan
+ key: ubuntu-24-cmake-vulkan
evict-old-files: 1d
- name: Dependencies
id: depends
run: |
- wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
- sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
+ sudo add-apt-repository -y ppa:kisak/kisak-mesa
sudo apt-get update -y
- sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev
+ sudo apt-get install -y build-essential mesa-vulkan-drivers libxcb-xinput0 libxcb-xinerama0 libxcb-cursor-dev libcurl4-openssl-dev
+
+ - name: Get latest Vulkan SDK version
+ id: vulkan_sdk_version
+ run: |
+ echo "VULKAN_SDK_VERSION=$(curl https://vulkan.lunarg.com/sdk/latest/linux.txt)" >> "$GITHUB_ENV"
+
+ - name: Cache Vulkan SDK
+ id: cache_vulkan_sdk
+ uses: actions/cache@v4
+ with:
+ path: ./vulkan_sdk
+ key: vulkan-sdk-${{ env.VULKAN_SDK_VERSION }}-${{ runner.os }}
+
+ - name: Install Vulkan SDK
+ if: steps.cache_vulkan_sdk.outputs.cache-hit != 'true'
+ id: vulkan_sdk_install
+ run: |
+ mkdir -p vulkan_sdk
+ cd vulkan_sdk
+ curl --no-progress-meter https://sdk.lunarg.com/sdk/download/latest/linux/vulkan_sdk.tar.xz | tar -Jx --strip-components=1
- name: Build
id: cmake_build
run: |
+ source ./vulkan_sdk/setup-env.sh
cmake -B build \
-DGGML_VULKAN=ON
cmake --build build --config Release -j $(nproc)
@@ -421,6 +441,7 @@ jobs:
run: |
cd build
export GGML_VK_VISIBLE_DEVICES=0
+ export GGML_VK_DISABLE_F16=1
# This is using llvmpipe and runs slower than other backends
ctest -L main --verbose --timeout 4200
@@ -487,7 +508,7 @@ jobs:
id: depends
run: |
sudo apt-get update
- sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev
+ sudo apt-get install -y build-essential git cmake rocblas-dev hipblas-dev libcurl4-openssl-dev rocwmma-dev
- name: ccache
uses: ggml-org/ccache-action@v1.2.16
@@ -1059,7 +1080,7 @@ jobs:
shell: bash
env:
- WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
+ WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
steps:
@@ -1097,10 +1118,12 @@ jobs:
id: checkout
uses: actions/checkout@v4
- - name: Clone rocWMMA repository
- id: clone_rocwmma
+ - name: Grab rocWMMA package
+ id: grab_rocwmma
run: |
- git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
+ curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb"
+ 7z x rocwmma.deb
+ 7z x data.tar
- name: Cache ROCm Installation
id: cache-rocm
@@ -1161,8 +1184,9 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
- -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" `
+ -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" `
-DCMAKE_BUILD_TYPE=Release `
+ -DROCM_DIR="${env:HIP_PATH}" `
-DGGML_HIP=ON `
-DGGML_HIP_ROCWMMA_FATTN=ON `
-DGGML_RPC=ON `
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index f4eae5da11..2ad3811594 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -75,7 +75,7 @@ jobs:
name: llama-bin-macos-arm64.zip
macOS-x64:
- runs-on: macos-13
+ runs-on: macos-15-intel
steps:
- name: Clone
@@ -462,7 +462,7 @@ jobs:
shell: bash
env:
- WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe
+ WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/24751ead-ddc5-4479-b9e6-f9fe2ff8b9f2/intel-deep-learning-essentials-2025.2.1.25_offline.exe
WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel
ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
@@ -505,6 +505,7 @@ jobs:
cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero_v2.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin
@@ -513,10 +514,15 @@ jobs:
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl-ls.exe" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin
cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/tcm/latest/bin/tcm.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/tcm/latest/bin/libhwloc-15.dll" ./build/bin
+ cp "${{ env.ONEAPI_ROOT }}/umf/latest/bin/umf.dll" ./build/bin
+
echo "cp oneAPI running time dll files to ./build/bin done"
7z a llama-bin-win-sycl-x64.zip ./build/bin/*
@@ -543,10 +549,12 @@ jobs:
id: checkout
uses: actions/checkout@v4
- - name: Clone rocWMMA repository
- id: clone_rocwmma
+ - name: Grab rocWMMA package
+ id: grab_rocwmma
run: |
- git clone https://github.com/rocm/rocwmma --branch develop --depth 1
+ curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb"
+ 7z x rocwmma.deb
+ 7z x data.tar
- name: Cache ROCm Installation
id: cache-rocm
@@ -601,7 +609,7 @@ jobs:
cmake -G "Unix Makefiles" -B build -S . `
-DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" `
-DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" `
- -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
+ -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" `
-DCMAKE_BUILD_TYPE=Release `
-DGGML_BACKEND_DL=ON `
-DGGML_NATIVE=OFF `
diff --git a/CODEOWNERS b/CODEOWNERS
index 89b84ce850..15e3559fda 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -59,6 +59,9 @@
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
+/ggml/src/ggml-cuda/fattn-wmma* @IMbackK
+/ggml/src/ggml-hip/ @IMbackK
+/ggml/src/ggml-cuda/vendors/hip.h @IMbackK
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
diff --git a/ci/run.sh b/ci/run.sh
index b0af51723b..c6e31fa0b8 100755
--- a/ci/run.sh
+++ b/ci/run.sh
@@ -34,9 +34,9 @@ mkdir -p "$2"
OUT=$(realpath "$1")
MNT=$(realpath "$2")
-rm -f "$OUT/*.log"
-rm -f "$OUT/*.exit"
-rm -f "$OUT/*.md"
+rm -f $OUT/*.log
+rm -f $OUT/*.exit
+rm -f $OUT/*.md
sd=`dirname $0`
cd $sd/../
@@ -607,6 +607,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
fi
ret=0
+
test $ret -eq 0 && gg_run ctest_debug
test $ret -eq 0 && gg_run ctest_release
@@ -624,4 +625,6 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
test $ret -eq 0 && gg_run ctest_with_model_release
fi
+cat $OUT/README.md
+
exit $ret
diff --git a/common/arg.cpp b/common/arg.cpp
index cbca8b5ac5..a020ac4413 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1615,18 +1615,14 @@ static void add_rpc_devices(const std::string & servers) {
if (!rpc_reg) {
throw std::invalid_argument("failed to find RPC backend");
}
- typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
- ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
- if (!ggml_backend_rpc_add_device_fn) {
- throw std::invalid_argument("failed to find RPC device add function");
+ typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
+ ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+ if (!ggml_backend_rpc_add_server_fn) {
+ throw std::invalid_argument("failed to find RPC add server function");
}
for (const auto & server : rpc_servers) {
- ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
- if (dev) {
- ggml_backend_device_register(dev);
- } else {
- throw std::invalid_argument("failed to register RPC device");
- }
+ auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+ ggml_backend_register(reg);
}
}
@@ -1932,13 +1928,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_SWA_FULL"));
add_opt(common_arg(
- {"--swa-checkpoints"}, "N",
- string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
- "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+ {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+ string_format("max number of context checkpoints to create per slot (default: %d)\n"
+ "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
[](common_params & params, int value) {
- params.n_swa_checkpoints = value;
+ params.n_ctx_checkpoints = value;
}
- ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+ ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--kv-unified", "-kvu"},
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp
index 96ba8f533e..b3362519a6 100644
--- a/common/chat-parser.cpp
+++ b/common/chat-parser.cpp
@@ -75,6 +75,35 @@ bool common_chat_msg_parser::add_tool_calls(const json & arr) {
}
return true;
}
+
+bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) {
+ if (!tool_call.is_object() || tool_call.size() != 1) {
+ return false;
+ }
+
+ // Get the tool name (the single key in the object)
+ auto it = tool_call.begin();
+ std::string name = it.key();
+
+ if (name.empty()) {
+ return false;
+ }
+
+ // Get the arguments (the nested object)
+ const json & args_json = it.value();
+ std::string arguments = "";
+
+ if (args_json.is_object()) {
+ arguments = args_json.dump();
+ } else if (args_json.is_string()) {
+ arguments = args_json;
+ } else if (!args_json.is_null()) {
+ // For other types, convert to string representation
+ arguments = args_json.dump();
+ }
+
+ return add_tool_call(name, "", arguments);
+}
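Illustrative sketch (not part of the patch): the "short form" this helper accepts is a single-key JSON object mapping the tool name to its arguments. The tool name and arguments below are hypothetical, shown with nlohmann::json as used elsewhere in this file.

```cpp
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::ordered_json;

int main() {
    // Short-form tool call: one key (the tool name) whose value is the argument object.
    json tool_call = { { "get_weather", { { "location", "Paris" }, { "unit", "celsius" } } } };

    auto it = tool_call.begin();
    std::string name      = it.key();          // "get_weather"
    std::string arguments = it.value().dump(); // "{\"location\":\"Paris\",\"unit\":\"celsius\"}"
    // add_tool_call_short_form() forwards these as (name, "", arguments).
    return 0;
}
```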
void common_chat_msg_parser::finish() {
if (!is_partial_ && pos_ != input_.size()) {
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
diff --git a/common/chat-parser.h b/common/chat-parser.h
index 0e64c341a5..c8cdc63fb5 100644
--- a/common/chat-parser.h
+++ b/common/chat-parser.h
@@ -64,6 +64,9 @@ class common_chat_msg_parser {
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
+ // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } }
+ bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call);
+
void finish();
bool consume_spaces();
diff --git a/common/chat.cpp b/common/chat.cpp
index e2bacdcf52..afbb2a2bdd 100644
--- a/common/chat.cpp
+++ b/common/chat.cpp
@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
+ case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
@@ -638,6 +639,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
+ case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -801,6 +803,7 @@ static std::string apply(
}
tmpl_inputs.add_generation_prompt = inputs.add_generation_prompt;
tmpl_inputs.extra_context = inputs.extra_context;
+ tmpl_inputs.extra_context["enable_thinking"] = inputs.enable_thinking;
if (additional_context) {
tmpl_inputs.extra_context.merge_patch(*additional_context);
}
@@ -982,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
return data;
}
+
+static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
+ data.preserved_tokens = {
+ "[THINK]",
+ "[/THINK]",
+ };
+
+ if (inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ {"type", "object"},
+ {"properties", {
+ {"name", {
+ {"type", "string"},
+ {"const", function.at("name")},
+ }},
+ {"arguments", function.at("parameters")},
+ {"id", {
+ {"type", "string"},
+ {"pattern", "^[a-zA-Z0-9]{9}$"},
+ }},
+ }},
+ {"required", json::array({"name", "arguments", "id"})},
+ });
+ });
+ auto schema = json {
+ {"type", "array"},
+ {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+ {"minItems", 1},
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
+ });
+ data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+ data.preserved_tokens.push_back("[TOOL_CALLS]");
+ } else {
+ data.grammar_lazy = false;
+ if (!inputs.json_schema.is_null()) {
+ if (!inputs.grammar.empty()) {
+ throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+ }
+ data.grammar = json_schema_to_grammar(inputs.json_schema);
+ } else {
+ data.grammar = inputs.grammar;
+ }
+ }
+
+ return data;
+}
+
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@@ -992,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
parse_prefixed_json_tool_call_array(builder, prefix);
}
+static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
+ builder.try_parse_reasoning("[THINK]", "[/THINK]");
+
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
+ parse_prefixed_json_tool_call_array(builder, prefix);
+}
+
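Illustrative sketch (not from the patch): the shape of assistant output the Magistral parser above is written to handle, with hypothetical tool name, arguments, and id. Reasoning is wrapped in [THINK]/[/THINK], and tool calls follow a [TOOL_CALLS] prefix as a JSON array of {name, arguments, id} objects, with the id matching the ^[a-zA-Z0-9]{9}$ pattern from the grammar.

```cpp
const char * example_magistral_output =
    "[THINK]The user wants the weather, so I should call the weather tool.[/THINK]"
    "[TOOL_CALLS][{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}, \"id\": \"abc123def\"}]";
```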
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
@@ -1264,6 +1338,75 @@ static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_
}
return data;
}
+
+static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+
+ // Generate the prompt using the apply() function with the template
+ data.prompt = apply(tmpl, inputs);
+ data.format = COMMON_CHAT_FORMAT_APERTUS;
+
+ // Handle thinking tags appropriately based on inputs.enable_thinking
+ if (string_ends_with(data.prompt, "<|inner_prefix|>")) {
+ if (!inputs.enable_thinking) {
+ data.prompt += "<|inner_suffix|>";
+ } else {
+ data.thinking_forced_open = true;
+ }
+ }
+
+ // When tools are present, build grammar for the <|tools_prefix|> format
+ if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
+ data.grammar_lazy = true;
+ data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+ auto schemas = json::array();
+ foreach_function(inputs.tools, [&](const json & tool) {
+ const auto & function = tool.at("function");
+ schemas.push_back({
+ { "type", "object" },
+ { "properties",
+ {
+ { function.at("name"), function.at("parameters") }
+ } },
+ { "required", json::array({ function.at("name") }) },
+ });
+ });
+ auto schema = json{
+ { "type", "array" },
+ { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
+ { "minItems", 1 },
+ };
+ if (!inputs.parallel_tool_calls) {
+ schema["maxItems"] = 1;
+ }
+ builder.add_rule("root",
+ std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") +
+ "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\"");
+ });
+ data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
+ // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar,
+ // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
+ std::string(data.thinking_forced_open ?
+ "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" :
+ "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") +
+ "(<\\|tools_prefix\\|>)[\\s\\S]*" });
+ data.preserved_tokens = {
+ "<|system_start|>",
+ "<|system_end|>",
+ "<|developer_start|>",
+ "<|developer_end|>",
+ "<|user_start|>",
+ "<|user_end|>",
+ "<|assistant_start|>",
+ "<|assistant_end|>",
+ "<|inner_prefix|>",
+ "<|inner_suffix|>",
+ "<|tools_prefix|>",
+ "<|tools_suffix|>",
+ };
+ }
+ return data;
+}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@@ -2323,6 +2466,37 @@ static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
+static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
+ // Parse thinking tags
+ builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+
+ // Look for tool calls
+ static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
+ if (auto res = builder.try_find_regex(tool_call_regex)) {
+ builder.move_to(res->groups[0].end);
+
+ auto tool_calls_data = builder.consume_json();
+ if (tool_calls_data.json.is_array()) {
+ builder.consume_spaces();
+ if (!builder.try_consume_literal("<|tools_suffix|>")) {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ for (const auto & value : tool_calls_data.json) {
+ if (value.is_object()) {
+ builder.add_tool_call_short_form(value);
+ }
+ }
+ } else {
+ throw common_chat_msg_partial_exception("Incomplete tool call");
+ }
+ }
+ builder.add_content(builder.consume_rest());
+}
+
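Illustrative sketch (not from the patch): the shape of assistant output the Apertus parser above expects, with a hypothetical tool call. Optional reasoning sits between <|inner_prefix|> and <|inner_suffix|>, and tool calls are a JSON array of short-form objects between <|tools_prefix|> and <|tools_suffix|>, matching add_tool_call_short_form above.

```cpp
const char * example_apertus_output =
    "<|inner_prefix|>I need the current weather before answering.<|inner_suffix|>"
    "<|tools_prefix|>[{\"get_weather\": {\"location\": \"Paris\"}}]<|tools_suffix|>";
```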
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("", "");
@@ -2567,6 +2741,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_nemotron_v2(tmpl, params);
}
+ // Apertus format detection
+ if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) {
+ return common_chat_params_init_apertus(tmpl, params);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -2595,6 +2774,10 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
}
+ if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
+ return common_chat_params_init_magistral(tmpl, params);
+ }
+
// Plain handler (no tools)
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, params);
@@ -2695,6 +2878,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
+ case COMMON_CHAT_FORMAT_MAGISTRAL:
+ common_chat_parse_magistral(builder);
+ break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
@@ -2734,6 +2920,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
+ case COMMON_CHAT_FORMAT_APERTUS:
+ common_chat_parse_apertus(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
diff --git a/common/chat.h b/common/chat.h
index 5170fc14f4..a1afe574bd 100644
--- a/common/chat.h
+++ b/common/chat.h
@@ -101,6 +101,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+ COMMON_CHAT_FORMAT_MAGISTRAL,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
@@ -114,6 +115,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
+ COMMON_CHAT_FORMAT_APERTUS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
diff --git a/common/common.h b/common/common.h
index 40c6847f32..d33788bd10 100644
--- a/common/common.h
+++ b/common/common.h
@@ -424,7 +424,7 @@ struct common_params {
int32_t timeout_write = timeout_read; // http write timeout in seconds
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
- int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
+ int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 948036586e..f59239fe9d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -894,6 +894,9 @@ class TextModel(ModelBase):
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"
+ if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
+ # ref: https://huggingface.co/ibm-granite/granite-docling-258M
+ res = "granite-docling"
if res is None:
logger.warning("\n")
@@ -1328,6 +1331,7 @@ class MmprojModel(ModelBase):
self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count)
# load preprocessor config
+ self.preprocessor_config = {}
if not self.is_mistral_format:
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
@@ -1350,7 +1354,8 @@ class MmprojModel(ModelBase):
self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
# vision config
- self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"]))
+ self.image_size = self.find_vparam(["image_size"])
+ self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
@@ -2381,6 +2386,10 @@ class SmolVLMModel(MmprojModel):
self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
self.gguf_writer.add_vision_use_gelu(True)
+ # Add the preprocessor longest edge size
+ preproc_image_size = self.preprocessor_config.get("size", {}).get("longest_edge", self.image_size)
+ self.gguf_writer.add_vision_preproc_image_size(preproc_image_size)
+
def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".embeddings." in name:
return gguf.GGMLQuantizationType.F32
@@ -4253,7 +4262,8 @@ class Plamo2Model(TextModel):
# This logic matches modeling_plamo.py's is_mamba function
mamba_step = hparams.get("mamba_step", 2)
mamba_enabled = hparams.get("mamba_enabled", True)
- mamba_layers = []
+ num_key_value_heads = []
+ num_attention_heads = []
if mamba_enabled:
for i in range(block_count):
@@ -4263,17 +4273,21 @@ class Plamo2Model(TextModel):
else:
is_mamba = (i % mamba_step) != (mamba_step // 2)
if is_mamba:
- mamba_layers.append(0)
+ num_key_value_heads.append(0)
+ num_attention_heads.append(0)
else:
- mamba_layers.append(hparams.get("num_key_value_heads", 4))
+ num_key_value_heads.append(hparams.get("num_key_value_heads", 4))
+ num_attention_heads.append(hparams.get("num_attention_heads", 32))
- if mamba_layers:
- self.gguf_writer.add_head_count_kv(mamba_layers)
+ if num_key_value_heads and num_attention_heads:
+ self.gguf_writer.add_head_count_kv(num_key_value_heads)
+ self.gguf_writer.add_head_count(num_attention_heads)
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 2048))
self.gguf_writer.add_embedding_length(hparams.get("hidden_size", 4096))
+ self.gguf_writer.add_key_length(hparams.get("hidden_size_per_head", 128))
+ self.gguf_writer.add_value_length(hparams.get("hidden_size_per_head", 128))
self.gguf_writer.add_block_count(block_count)
- self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
@@ -8972,6 +8986,43 @@ class ModernBertModel(BertModel):
+@ModelBase.register("ApertusForCausalLM")
+class ApertusModel(LlamaModel):
+ model_arch = gguf.MODEL_ARCH.APERTUS
+ undo_permute = False
+
+ _alpha_n = {}
+ _alpha_p = {}
+ _beta = {}
+ _eps = {}
+
+ def modify_tensors(self, data_torch, name, bid):
+ # Handle xIELU activation parameters
+ n_layers = self.hparams["num_hidden_layers"]
+ if name.endswith(".act_fn.alpha_n"):
+ self._alpha_n[bid] = data_torch.to("cpu").float().item()
+ if (len(self._alpha_n) == n_layers):
+ self.gguf_writer.add_xielu_alpha_n([self._alpha_n[k] for k in sorted(self._alpha_n)])
+ return []
+ if name.endswith(".act_fn.alpha_p"):
+ self._alpha_p[bid] = data_torch.to("cpu").float().item()
+ if (len(self._alpha_p) == n_layers):
+ self.gguf_writer.add_xielu_alpha_p([self._alpha_p[k] for k in sorted(self._alpha_p)])
+ return []
+ if name.endswith(".act_fn.beta"):
+ self._beta[bid] = data_torch.to("cpu").float().item()
+ if (len(self._beta) == n_layers):
+ self.gguf_writer.add_xielu_beta([self._beta[k] for k in sorted(self._beta)])
+ return []
+ if name.endswith(".act_fn.eps"):
+ self._eps[bid] = data_torch.to("cpu").float().item()
+ if (len(self._eps) == n_layers):
+ self.gguf_writer.add_xielu_eps([self._eps[k] for k in sorted(self._eps)])
+ return []
+
+ return super().modify_tensors(data_torch, name, bid)
+
+
class MistralModel(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
model_name = "Mistral"
@@ -9139,7 +9190,7 @@ class LazyTorchTensor(gguf.LazyBase):
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
dtype = cls._dtype_str_map[st_slice.get_dtype()]
shape: tuple[int, ...] = tuple(st_slice.get_shape())
- lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+ lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[...] if len(s.get_shape()) == 0 else s[:])
return cast(torch.Tensor, lazy)
@classmethod
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 8f2467194d..64db41c497 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -141,6 +141,7 @@ models = [
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-embedding-small-english-r2", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
+ {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md
index 6e9b88935d..92ab27066b 100644
--- a/docs/backend/SYCL.md
+++ b/docs/backend/SYCL.md
@@ -145,12 +145,13 @@ The docker build option is currently limited to *Intel GPU* targets.
```sh
# Using FP16
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
+
+# Using FP32
+docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile .
```
*Notes*:
-To build in default FP32 *(Slower than FP16 alternative)*, set `--build-arg="GGML_SYCL_F16=OFF"` in the previous command.
-
You can also use the `.devops/llama-server-intel.Dockerfile`, which builds the *"server"* alternative.
Check the [documentation for Docker](../docker.md) to see the available images.
@@ -160,7 +161,7 @@ Check the [documentation for Docker](../docker.md) to see the available images.
# First, find all the DRI cards
ls -la /dev/dri
# Then, pick the card that you want to use (here for e.g. /dev/dri/card1).
-docker run -it --rm -v "$(pwd):/app:Z" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card1:/dev/dri/card1 llama-cpp-sycl -m "/app/models/YOUR_MODEL_FILE" -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33
+docker run -it --rm -v "/path/to/models:/models" --device /dev/dri/renderD128:/dev/dri/renderD128 --device /dev/dri/card0:/dev/dri/card0 llama-cpp-sycl -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -c 4096 -s 0
```
*Notes:*
@@ -215,9 +216,19 @@ To target AMD GPUs with SYCL, the ROCm stack must be installed first.
2. **Install Intel® oneAPI Base toolkit**
+The SYCL backend depends on:
+ - Intel® oneAPI DPC++/C++ compiler and runtime.
+ - Intel® oneAPI DPC++/C++ library (oneDPL).
+ - Intel® oneAPI Deep Neural Network Library (oneDNN).
+ - Intel® oneAPI Math Kernel Library (oneMKL).
+
- **For Intel GPU**
-The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
+All of the above are included in both the **Intel® oneAPI Base Toolkit** and the **Intel® Deep Learning Essentials** packages.
+
+Installing **Intel® Deep Learning Essentials** is recommended, as it provides only the necessary libraries and has a smaller footprint.
+
+The **Intel® oneAPI Base toolkit** and **Intel® Deep Learning Essentials** can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
Please follow the instructions for downloading and installing the Toolkit for Linux, and preferably keep the default installation values unchanged, notably the installation path *(`/opt/intel/oneapi` by default)*.
@@ -225,6 +236,12 @@ Following guidelines/code snippets assume the default installation values. Other
Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs.
+|Verified release|
+|-|
+|2025.2.1|
+|2025.1|
+|2024.1|
+
- **Adding support to Nvidia GPUs**
**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.
@@ -255,10 +272,11 @@ sycl-ls
When targeting an intel GPU, the user should expect one or more devices among the available SYCL devices. Please make sure that at least one GPU is present via `sycl-ls`, for instance `[level_zero:gpu]` in the sample output below:
```
-[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000]
-[opencl:cpu][opencl:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i7-13700K OpenCL 3.0 (Build 0) [2023.16.10.0.17_160000]
-[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [23.30.26918.50]
-[level_zero:gpu][level_zero:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26918]
+[level_zero:gpu][level_zero:0] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) Arc(TM) A770 Graphics 12.55.8 [1.3.29735+27]
+[level_zero:gpu][level_zero:1] Intel(R) oneAPI Unified Runtime over Level-Zero, Intel(R) UHD Graphics 730 12.2.0 [1.3.29735+27]
+[opencl:cpu][opencl:0] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i5-13400 OpenCL 3.0 (Build 0) [2025.20.8.0.06_160000]
+[opencl:gpu][opencl:1] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics OpenCL 3.0 NEO [24.39.31294]
+[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO [24.39.31294]
```
- **Nvidia GPU**
@@ -353,7 +371,7 @@ cmake --build build --config Release -j -v
#### Retrieve and prepare model
-You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
+You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf).
##### Check device
@@ -466,7 +484,17 @@ If you already have a recent version of Microsoft Visual Studio, you can skip th
3. Install Intel® oneAPI Base toolkit
-The base toolkit can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
+The SYCL backend depends on:
+ - Intel® oneAPI DPC++/C++ compiler and runtime.
+ - Intel® oneAPI DPC++/C++ library (oneDPL).
+ - Intel® oneAPI Deep Neural Network Library (oneDNN).
+ - Intel® oneAPI Math Kernel Library (oneMKL).
+
+All of the above are included in both the **Intel® oneAPI Base Toolkit** and the **Intel® Deep Learning Essentials** packages.
+
+Installing **Intel® Deep Learning Essentials** is recommended, as it provides only the necessary libraries and has a smaller footprint.
+
+The **Intel® oneAPI Base toolkit** and **Intel® Deep Learning Essentials** can be obtained from the official [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) page.
Please follow the instructions for downloading and installing the Toolkit for Windows, and preferably keep the default installation values unchanged, notably the installation path *(`C:\Program Files (x86)\Intel\oneAPI` by default)*.
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 56420587a9..6ce52ffc66 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -209,7 +209,6 @@ option(GGML_HIP "ggml: use HIP"
option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
-option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF)
option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
option(GGML_HIP_EXPORT_METRICS "ggml: enable kernel perf metrics output" OFF)
option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 62b6d65e51..f1b7407859 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -215,6 +215,8 @@ extern "C" {
// Backend registry
//
+ GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
+
GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
// Backend (reg) enumeration
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index 1e67411276..72eff00273 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -7,26 +7,25 @@
extern "C" {
#endif
-#define RPC_PROTO_MAJOR_VERSION 2
+#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
// backend API
-GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device);
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
-GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device);
-GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
-GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
- const char * cache_dir,
- size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+ size_t n_threads, size_t n_devices,
+ ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
-
-GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
#ifdef __cplusplus
}
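A minimal sketch of how the reworked RPC API might be used from client code, assuming a server listening at `localhost:50052` (the endpoint and device index are illustrative, and this mirrors the registration flow in `common/arg.cpp` above rather than prescribing a canonical usage):

```cpp
#include "ggml-backend.h"
#include "ggml-rpc.h"

static ggml_backend_t connect_rpc_example(void) {
    const char * endpoint = "localhost:50052"; // illustrative endpoint

    // One registry per RPC server; registering it exposes all of that server's devices.
    ggml_backend_reg_t reg = ggml_backend_rpc_add_server(endpoint);
    ggml_backend_register(reg);

    // Backends, buffer types and memory queries are now addressed by (endpoint, device) pairs.
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(endpoint, /*device =*/ 0, &free_mem, &total_mem);

    return ggml_backend_rpc_init(endpoint, /*device =*/ 0);
}
```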
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 5028a9cebf..60c6b63d05 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -576,6 +576,7 @@ extern "C" {
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
+ GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_COUNT,
};
@@ -1150,6 +1151,18 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // xIELU activation function
+ // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
+ // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+ // that constrain the positive and negative source alpha values respectively
+ GGML_API struct ggml_tensor * ggml_xielu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha_n,
+ float alpha_p,
+ float beta,
+ float eps);
+
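A small sketch of applying the new op while building a graph; the tensor shape and parameter values are arbitrary illustrations, not defaults from the patch:

```c
#include "ggml.h"

static struct ggml_tensor * build_xielu_example(struct ggml_context * ctx) {
    // 8-element F32 input vector (illustrative shape).
    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    // Apply xIELU with illustrative (not model-derived) parameters.
    return ggml_xielu(ctx, x, /*alpha_n =*/ 0.8f, /*alpha_p =*/ 0.8f, /*beta =*/ 0.5f, /*eps =*/ -1e-6f);
}
```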
// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
@@ -1617,6 +1630,13 @@ extern "C" {
float scale,
float max_bias);
+ GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias);
+
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index fa46f3b491..929bc44881 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) {
free(alloc);
}
-static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) {
- size_t max_size = 0;
- for (int i = 0; i < alloc->n_chunks; i++) {
- max_size += alloc->chunks[i]->max_size;
- }
- return max_size;
+static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) {
+ return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0;
}
@@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) {
free(buf);
}
-static int ggml_vbuffer_n_chunks(struct vbuffer * buf) {
- int n = 0;
- while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++;
- return n;
+static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) {
+ return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0;
}
static size_t ggml_vbuffer_size(struct vbuffer * buf) {
@@ -885,12 +879,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
}
- size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
- size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]);
-
// even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
- if (new_size > cur_size || galloc->buffers[i] == NULL) {
+ bool realloc = galloc->buffers[i] == NULL;
+ size_t new_size = 0;
+ for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) {
+ size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0;
+ size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c);
+ new_size += new_chunk_size;
+ if (new_chunk_size > cur_chunk_size) {
+ realloc = true;
+ }
+ }
+ if (realloc) {
#ifndef NDEBUG
+ size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
#endif
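To illustrate the behavioural change with hypothetical sizes (not taken from the patch): the old code compared the summed max sizes, so a buffer whose chunk layout shifted could skip reallocation even though one chunk grew; the new loop compares each chunk individually.

```c
#include <stdbool.h>
#include <stddef.h>

// Hypothetical sizes in MiB: one chunk grows while another shrinks.
static bool needs_realloc_per_chunk(void) {
    const size_t cur[2] = { 64, 64 };   // existing chunk sizes (total 128)
    const size_t req[2] = { 96, 16 };   // newly required max sizes (total 112)
    // Summed comparison (old behaviour): 112 <= 128, so no reallocation,
    // even though chunk 0 can no longer hold its tensors.
    // Per-chunk comparison (new behaviour): chunk 0 needs 96 > 64, so reallocate.
    for (int c = 0; c < 2; c++) {
        if (req[c] > cur[c]) {
            return true;
        }
    }
    return false;
}
```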
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 07784d6f66..6792ba986e 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -209,9 +209,6 @@ extern "C" {
void * context;
};
- // Internal backend registry API
- GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-
// Add backend dynamic loading support to the backend
// Initialize the backend
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index dbc07301b2..eded6eb77e 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_XIELU:
{
n_tasks = n_threads;
} break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 14f7dcf4f4..6275c8305a 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = ggml_softplus(dt[h]);
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
@@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32(
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = ggml_softplus(dt[h]);
const int g = h / (nh / ng); // repeat_interleave
// dim
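The two call sites above now share a helper; presumably `ggml_softplus` implements the same numerically guarded softplus that the removed inline expression did, along the lines of this sketch (the helper itself is defined outside this patch and may differ in detail):

```c
#include <math.h>

// Sketch of the behaviour the removed inline code implemented.
static inline float softplus_ref(float x) {
    // For large x, softplus(x) = log(1 + exp(x)) ~= x and exp(x) would overflow,
    // so fall back to the identity beyond a threshold.
    return x <= 20.0f ? log1pf(expf(x)) : x;
}
```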
@@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_exp(params, dst);
} break;
+ case GGML_UNARY_OP_XIELU:
+ {
+ ggml_compute_forward_xielu(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp
index 4fce569b3b..cf1a4615d0 100644
--- a/ggml/src/ggml-cpu/unary-ops.cpp
+++ b/ggml/src/ggml-cpu/unary-ops.cpp
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
return sqrtf(x);
}
+static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+ if (x > 0.0f) {
+ return alpha_p * x * x + beta * x;
+ } else {
+ const float min_x_eps = fminf(x, eps);
+ return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+ }
+}
+
static inline float op_sin(float x) {
return sinf(x);
}
@@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
}
}
+template <float (*op)(float)>
+static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+ apply_unary_op<op, float, float>(params, dst);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+ apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+ apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+// Extend vec_unary_op to support functors
+template <typename src0_t, typename dst_t, typename Op>
+static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+ constexpr auto f32_to_dst = type_conversion_table<dst_t>::from_f32;
+
+ for (int i = 0; i < n; i++) {
+ y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+ }
+}
+
+// Extend apply_unary_op to support functors
+template <typename src0_t, typename dst_t, typename Op>
+static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(dst_t));
+ GGML_ASSERT(nb00 == sizeof(src0_t));
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
+ }
+}
+
+// Generic dispatcher for functors
+template <typename Op>
+static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+ apply_unary_op_functor<float, float>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+ apply_unary_op_functor<ggml_fp16_t, ggml_fp16_t>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+ apply_unary_op_functor<ggml_bf16_t, ggml_bf16_t>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op_functor<ggml_bf16_t, float>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op_functor<ggml_fp16_t, float>(params, dst, op);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op(params, dst);
}
@@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor *
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op(params, dst);
}
+
+void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
+ const float alpha_n = ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = ggml_get_op_params_f32(dst, 2);
+ const float beta = ggml_get_op_params_f32(dst, 3);
+ const float eps = ggml_get_op_params_f32(dst, 4);
+
+ const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+ return op_xielu(f, alpha_n, alpha_p, beta, eps);
+ };
+
+ unary_op_functor(params, dst, xielu_op_params);
+}
+
diff --git a/ggml/src/ggml-cpu/unary-ops.h b/ggml/src/ggml-cpu/unary-ops.h
index b1ade2c8e3..697c1e0da0 100644
--- a/ggml/src/ggml-cpu/unary-ops.h
+++ b/ggml/src/ggml-cpu/unary-ops.h
@@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c4246b65eb..d51abbeafa 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -220,14 +220,6 @@ static const char * cu_get_error_str(CUresult err) {
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-#define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
-
-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-#define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
-
#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
@@ -262,27 +254,6 @@ static bool fast_fp16_hardware_available(const int cc) {
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}
-// Any FP16 tensor core instructions are available for ggml code.
-static bool fp16_mma_available(const int cc) {
-#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
- return false;
-#else
- if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
- GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
- GGML_CUDA_CC_IS_MTHREADS(cc)) {
- return true;
- } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
- return true;
-#else
- return false;
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
- } else {
- return false;
- }
-#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
-}
-
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index 131a5099a3..68de623d80 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -1,6 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation
@@ -190,10 +191,10 @@ static __global__ void flash_attn_tile(
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
NO_DEVICE_CODE;
return;
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 59c62553b0..89ab0f1638 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -535,8 +535,6 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
index 2219191fd9..6c90d6d52b 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -6,19 +6,19 @@
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"
-#ifdef FP16_MMA_AVAILABLE
+#ifdef GGML_USE_WMMA_FATTN
#if !defined(GGML_USE_HIP)
#include <mma.h>
-#ifdef GGML_USE_MUSA
+#if defined(GGML_USE_MUSA)
namespace wmma = mtmusa::wmma;
#else // GGML_USE_MUSA
namespace wmma = nvcuda::wmma;
#endif // GGML_USE_MUSA
-#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#elif defined(GGML_USE_HIP)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif // !defined(GGML_USE_HIP)
-#endif // FP16_MMA_AVAILABLE
+#endif // GGML_USE_WMMA_FATTN
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template
@@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16(
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)))
}
constexpr int get_max_power_of_2(int x) {
diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index beeea95eb1..1848d08836 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -1,3 +1,49 @@
#include "common.cuh"
+#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+ return false;
+#else
+ if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+ GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+ return true;
+ } else if (GGML_CUDA_CC_IS_CDNA(cc)){
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ } else {
+ return false;
+ }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
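For readers of the version gates above: the CDNA branch is meant to keep WMMA flash attention off when building against the broken rocwmma v2.0.0 release, while RDNA4 requires rocwmma v2 or newer. A minimal sketch of the CDNA predicate written out as plain C++, taking the components of the real ROCWMMA_VERSION_* macros as inputs (the helper name and the static_asserts are illustrative only, not part of the patch):

static constexpr bool rocwmma_cdna_fattn_usable(int major, int minor, int patch) {
    // anything older than 2.0.0 passes, as does any 2.x (or later) release with a
    // nonzero minor or patch; the bare v2.0.0 release itself is rejected
    return major < 2 || minor > 0 || patch > 0;
}
static_assert(!rocwmma_cdna_fattn_usable(2, 0, 0), "broken v2.0.0 is excluded");
static_assert( rocwmma_cdna_fattn_usable(1, 7, 0), "pre-2.0 releases pass");
static_assert( rocwmma_cdna_fattn_usable(2, 0, 1), "patched 2.x releases pass");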
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 1cbd4f5bd6..d7736d3610 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -222,7 +222,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+ if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
return BEST_FATTN_KERNEL_NONE;
}
break;
@@ -300,7 +300,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// For large batch sizes, use the WMMA kernel if possible:
- if (fp16_mma_available(cc)) {
+ if (ggml_cuda_should_use_wmma_fattn(cc)) {
return BEST_FATTN_KERNEL_WMMA_F16;
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index b7e81b21bc..26e72bbc2b 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2334,6 +2334,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_ELU:
ggml_cuda_op_elu(ctx, dst);
break;
+ case GGML_UNARY_OP_XIELU:
+ ggml_cuda_op_xielu(ctx, dst);
+ break;
default:
return false;
}
diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu
index 039f284719..afe4aee240 100644
--- a/ggml/src/ggml-cuda/topk-moe.cu
+++ b/ggml/src/ggml-cuda/topk-moe.cu
@@ -13,7 +13,7 @@
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/
-template
+template
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights,
int32_t * ids,
@@ -204,8 +204,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
- cudaStream_t stream = ctx.stream();
-
const int n_expert_used = weights->ne[1];
if (with_norm) {
diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 5aff8a876a..3c564566a5 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -1,4 +1,5 @@
#include "unary.cuh"
+#include "convert.cuh"
static __device__ __forceinline__ float op_abs(float x) {
return fabsf(x);
@@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
+/* CUDA kernel + launcher for xIELU */
+
+template <typename T>
+static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float xi = ggml_cuda_cast<float>(x[i]);
+
+ const float gate_pos = (xi > 0.0f);
+ const float y_pos = alpha_p * xi * xi + beta * xi;
+ const float min_v_eps = fminf(xi, eps);
+ const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
+ const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
+
+ dst[i] = ggml_cuda_cast<T>(out);
+}
+
+template <typename T>
+static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE;
+ xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
+}
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ const float alpha_n = ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = ggml_get_op_params_f32(dst, 2);
+ const float beta = ggml_get_op_params_f32(dst, 3);
+ const float eps = ggml_get_op_params_f32(dst, 4);
+
+ if (src0->type == GGML_TYPE_F16) {
+ xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ } else {
+ xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ }
+}
+
+
+
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
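For reference, the xIELU formula implemented by xielu_kernel above can be sketched on the CPU as follows; xielu_ref is an illustrative name and not part of the patch:

#include <algorithm>
#include <cmath>

// positive branch: alpha_p * x^2 + beta * x
// negative branch: alpha_n * (expm1(min(x, eps)) - x) + beta * x
static float xielu_ref(float x, float alpha_n, float alpha_p, float beta, float eps) {
    if (x > 0.0f) {
        return alpha_p * x * x + beta * x;
    }
    const float m = std::min(x, eps);
    return alpha_n * (std::expm1(m) - x) + beta * x;
}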
diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh
index da3caf1d89..8e7644fcd9 100644
--- a/ggml/src/ggml-cuda/unary.cuh
+++ b/ggml/src/ggml-cuda/unary.cuh
@@ -16,6 +16,7 @@
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
#define CUDA_GLU_BLOCK_SIZE 256
+#define CUDA_XIELU_BLOCK_SIZE 256
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
@@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 37386afcd4..890c103649 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -6,6 +6,10 @@
#include
#include
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index d327b90cce..0e2b1847e0 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -39,12 +39,6 @@ endif()
find_package(hip REQUIRED)
find_package(hipblas REQUIRED)
find_package(rocblas REQUIRED)
-if (GGML_HIP_ROCWMMA_FATTN)
- CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
- if (NOT ${FOUND_ROCWMMA})
- message(FATAL_ERROR "rocwmma has not been found")
- endif()
-endif()
if (${hip_VERSION} VERSION_LESS 6.1)
message(FATAL_ERROR "At least ROCM/HIP V6.1 is required")
@@ -117,10 +111,6 @@ if (NOT GGML_HIP_MMQ_MFMA)
add_compile_definitions(GGML_HIP_NO_MMQ_MFMA)
endif()
-if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0)
- add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12)
-endif()
-
if (GGML_HIP_EXPORT_METRICS)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps")
endif()
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 86a1ebf62b..d0fb3bccad 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) {
}
}
+static inline float ggml_softplus(float input) {
+ return (input > 20.0f) ? input : logf(1 + expf(input));
+}
//
// logging
//
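A note on the 20.0f cutoff in ggml_softplus above: for x = 20, log(1 + e^x) = x + log(1 + e^-x) ≈ x + 2.1e-9, which is far below float precision at that magnitude, so returning x unchanged is exact, and the branch also keeps expf from overflowing for very large inputs. An equivalent sketch using log1pf from <math.h> for slightly better accuracy near zero (illustrative only, not part of the patch):

static inline float softplus_ref(float x) {
    return (x > 20.0f) ? x : log1pf(expf(x));
}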
diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp
index dc7d241c3a..95627d3866 100644
--- a/ggml/src/ggml-metal/ggml-metal-common.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t
}
bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
}
@@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor *
}
bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
- for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
return false;
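The two loop-bound fixes above matter because ggml_tensor::src holds GGML_MAX_SRC entries (currently 10), while GGML_MAX_DIMS is 4, so the old loops silently ignored sources 4..9 when tracking memory ranges. A minimal sketch of the corrected scan (add_all_srcs is an illustrative name; ggml_mem_ranges_add_src is the helper from the file above):

static void add_all_srcs(ggml_mem_ranges_t mrs, const ggml_tensor * t) {
    for (int i = 0; i < GGML_MAX_SRC; i++) {   // GGML_MAX_SRC, not GGML_MAX_DIMS
        if (t->src[i]) {
            ggml_mem_ranges_add_src(mrs, t->src[i]);
        }
    }
}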
diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt
index cdb3818c78..f8477a2ef3 100644
--- a/ggml/src/ggml-musa/CMakeLists.txt
+++ b/ggml/src/ggml-musa/CMakeLists.txt
@@ -56,7 +56,7 @@ if (MUSAToolkit_FOUND)
set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
foreach(SOURCE ${GGML_SOURCES_MUSA})
- set(COMPILE_FLAGS "-fsigned-char -x musa -mtgpu")
+ set(COMPILE_FLAGS "-Od3 -fno-strict-aliasing -ffast-math -fsigned-char -x musa -mtgpu -fmusa-flush-denormals-to-zero")
foreach(ARCH ${MUSA_ARCHITECTURES})
set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
endforeach()
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index f99681c84c..aad48d62a8 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -105,9 +105,12 @@ enum rpc_cmd {
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
+ RPC_CMD_DEVICE_COUNT,
RPC_CMD_COUNT,
};
+static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14");
+
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
@@ -117,7 +120,12 @@ struct rpc_msg_hello_rsp {
uint8_t patch;
};
+struct rpc_msg_device_count_rsp {
+ uint32_t device_count;
+};
+
struct rpc_msg_get_alloc_size_req {
+ uint32_t device;
rpc_tensor tensor;
};
@@ -130,6 +138,7 @@ struct rpc_msg_init_tensor_req {
};
struct rpc_msg_alloc_buffer_req {
+ uint32_t device;
uint64_t size;
};
@@ -138,10 +147,18 @@ struct rpc_msg_alloc_buffer_rsp {
uint64_t remote_size;
};
+struct rpc_msg_get_alignment_req {
+ uint32_t device;
+};
+
struct rpc_msg_get_alignment_rsp {
uint64_t alignment;
};
+struct rpc_msg_get_max_size_req {
+ uint32_t device;
+};
+
struct rpc_msg_get_max_size_rsp {
uint64_t max_size;
};
@@ -192,6 +209,10 @@ struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
+struct rpc_msg_get_device_memory_req {
+ uint32_t device;
+};
+
struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
@@ -207,13 +228,15 @@ static ggml_guid_t ggml_backend_rpc_guid() {
struct ggml_backend_rpc_buffer_type_context {
std::string endpoint;
+ uint32_t device;
std::string name;
- size_t alignment;
- size_t max_size;
+ size_t alignment;
+ size_t max_size;
};
struct ggml_backend_rpc_context {
std::string endpoint;
+ uint32_t device;
std::string name;
};
@@ -608,23 +631,30 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
RPC_STATUS_ASSERT(status);
}
+static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
+ return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
+}
+
static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
- // check if src and dst are on the same server
- ggml_backend_buffer_t src_buffer = src->buffer;
- ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
- ggml_backend_buffer_t dst_buffer = dst->buffer;
- ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
- if (src_ctx->sock != dst_ctx->sock) {
- return false;
+ if (ggml_backend_buffer_is_rpc(src->buffer)) {
+ // check if src and dst are on the same server
+ ggml_backend_buffer_t src_buffer = src->buffer;
+ ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
+ ggml_backend_buffer_t dst_buffer = dst->buffer;
+ ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
+ if (src_ctx->sock != dst_ctx->sock) {
+ return false;
+ }
+ ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+ rpc_msg_copy_tensor_req request;
+ request.src = serialize_tensor(src);
+ request.dst = serialize_tensor(dst);
+ rpc_msg_copy_tensor_rsp response;
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
+ RPC_STATUS_ASSERT(status);
+ return response.result;
}
- ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
- rpc_msg_copy_tensor_req request;
- request.src = serialize_tensor(src);
- request.dst = serialize_tensor(dst);
- rpc_msg_copy_tensor_rsp response;
- bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
- RPC_STATUS_ASSERT(status);
- return response.result;
+ return false;
}
static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -653,7 +683,7 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
- rpc_msg_alloc_buffer_req request = {size};
+ rpc_msg_alloc_buffer_req request = {buft_ctx->device, size};
rpc_msg_alloc_buffer_rsp response;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
@@ -669,9 +699,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
}
}
-static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
+static size_t get_alignment(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+ rpc_msg_get_alignment_req request = {device};
rpc_msg_get_alignment_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.alignment;
}
@@ -681,9 +712,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
return buft_ctx->alignment;
}
-static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
+static size_t get_max_size(const std::shared_ptr<socket_t> & sock, uint32_t device) {
+ rpc_msg_get_max_size_req request = {device};
rpc_msg_get_max_size_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return response.max_size;
}
@@ -700,7 +732,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty
auto sock = get_socket(buft_ctx->endpoint);
rpc_msg_get_alloc_size_req request;
-
+ request.device = buft_ctx->device;
request.tensor = serialize_tensor(tensor);
rpc_msg_get_alloc_size_rsp response;
@@ -754,7 +786,7 @@ static void add_tensor(ggml_tensor * tensor, std::vector & tensors,
tensors.push_back(serialize_tensor(tensor));
}
-static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
+static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
uint32_t n_nodes = cgraph->n_nodes;
std::vector<rpc_tensor> tensors;
std::unordered_set<ggml_tensor *> visited;
@@ -762,24 +794,29 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector & o
add_tensor(cgraph->nodes[i], tensors, visited);
}
// serialization format:
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
uint32_t n_tensors = tensors.size();
- int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+ int output_size = 2*sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
output.resize(output_size, 0);
- memcpy(output.data(), &n_nodes, sizeof(n_nodes));
+ uint8_t * dest = output.data();
+ memcpy(dest, &device, sizeof(device));
+ dest += sizeof(device);
+ memcpy(dest, &n_nodes, sizeof(n_nodes));
+ dest += sizeof(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) {
- memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
+ memcpy(dest + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
}
- uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
- *out_ntensors = n_tensors;
- rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
+ dest += n_nodes * sizeof(uint64_t);
+ memcpy(dest, &n_tensors, sizeof(n_tensors));
+ dest += sizeof(n_tensors);
+ rpc_tensor * out_tensors = (rpc_tensor *)dest;
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
- serialize_graph(cgraph, input);
+ serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
@@ -804,12 +841,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_optimize = */ NULL,
};
-ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
+ std::string buft_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
// NOTE: buffer types are allocated and never freed; this is by design
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
- auto it = buft_map.find(endpoint);
+ auto it = buft_map.find(buft_name);
if (it != buft_map.end()) {
return it->second;
}
@@ -818,34 +856,37 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return nullptr;
}
- size_t alignment = get_alignment(sock);
- size_t max_size = get_max_size(sock);
+ size_t alignment = get_alignment(sock, device);
+ size_t max_size = get_max_size(sock, device);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
+ /* .device = */ device,
+ /* .name = */ buft_name,
/* .alignment = */ alignment,
/* .max_size = */ max_size
};
-
+ auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ buft_ctx
};
- buft_map[endpoint] = buft;
+ buft_map[buft_name] = buft;
return buft;
}
-ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
+ std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
+ /* .endpoint = */ endpoint,
+ /* .device = */ device,
+ /* .name = */ dev_name
};
-
+ auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .iface = */ ggml_backend_rpc_interface,
- /* .device = */ ggml_backend_rpc_add_device(endpoint),
+ /* .device = */ ggml_backend_reg_dev_get(reg, device),
/* .context = */ ctx
};
return backend;
@@ -855,37 +896,39 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) {
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}
-static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
+static void get_device_memory(const std::shared_ptr<socket_t> & sock, uint32_t device, size_t * free, size_t * total) {
+ rpc_msg_get_device_memory_req request;
+ request.device = device;
rpc_msg_get_device_memory_rsp response;
- bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
*free = response.free_mem;
*total = response.total_mem;
}
-void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total) {
auto sock = get_socket(endpoint);
if (sock == nullptr) {
*free = 0;
*total = 0;
return;
}
- get_device_memory(sock, free, total);
+ get_device_memory(sock, device, free, total);
}
// RPC server-side implementation
class rpc_server {
public:
- rpc_server(ggml_backend_t backend, const char * cache_dir)
- : backend(backend), cache_dir(cache_dir) {
+ rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
+ : backends(std::move(backends)), cache_dir(cache_dir) {
}
~rpc_server();
void hello(rpc_msg_hello_rsp & response);
- void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
- void get_alignment(rpc_msg_get_alignment_rsp & response);
- void get_max_size(rpc_msg_get_max_size_rsp & response);
+ bool alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+ bool get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response);
+ bool get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response);
bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
bool free_buffer(const rpc_msg_free_buffer_req & request);
bool buffer_clear(const rpc_msg_buffer_clear_req & request);
@@ -906,7 +949,7 @@ private:
std::unordered_map<uint64_t, struct ggml_tensor *> & tensor_map);
- ggml_backend_t backend;
+ std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
};
@@ -919,6 +962,10 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
}
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
ggml_backend_buffer_type_t buft;
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead(),
@@ -935,10 +982,10 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
return false;
}
- LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
+ LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
if (tensor->buffer == nullptr) {
//No buffer allocated.
- buft = ggml_backend_get_default_buffer_type(backend);
+ buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
} else {
buft = tensor->buffer->buft;
}
@@ -948,33 +995,49 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
return true;
}
-void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
response.remote_ptr = 0;
response.remote_size = 0;
if (buffer != nullptr) {
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
response.remote_size = buffer->size;
- LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n",
+ __func__, dev_id, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer);
} else {
- LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
+ LOG_DBG("[%s] device: %d, size: %" PRIu64 " -> failed\n", __func__, dev_id, request.size);
}
+ return true;
}
-void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_alignment(const rpc_msg_get_alignment_req & request, rpc_msg_get_alignment_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t alignment = ggml_backend_buft_get_alignment(buft);
- LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
+ LOG_DBG("[%s] device: %d, alignment: %lu\n", __func__, dev_id, alignment);
response.alignment = alignment;
+ return true;
}
-void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
- ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+bool rpc_server::get_max_size(const rpc_msg_get_max_size_req & request, rpc_msg_get_max_size_rsp & response) {
+ uint32_t dev_id = request.device;
+ if (dev_id >= backends.size()) {
+ return false;
+ }
+ ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backends[dev_id]);
size_t max_size = ggml_backend_buft_get_max_size(buft);
- LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
+ LOG_DBG("[%s] device: %d, max_size: %lu\n", __func__, dev_id, max_size);
response.max_size = max_size;
+ return true;
}
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
@@ -1332,23 +1395,33 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
// serialization format:
- // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
- if (input.size() < sizeof(uint32_t)) {
+ // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+ if (input.size() < 2*sizeof(uint32_t)) {
+ return false;
+ }
+ const uint8_t * src = input.data();
+ uint32_t device;
+ memcpy(&device, src, sizeof(device));
+ src += sizeof(device);
+ if (device >= backends.size()) {
return false;
}
uint32_t n_nodes;
- memcpy(&n_nodes, input.data(), sizeof(n_nodes));
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
+ memcpy(&n_nodes, src, sizeof(n_nodes));
+ src += sizeof(n_nodes);
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
return false;
}
- const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
+ const uint64_t * nodes = (const uint64_t *)src;
+ src += n_nodes*sizeof(uint64_t);
uint32_t n_tensors;
- memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
- if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
+ memcpy(&n_tensors, src, sizeof(n_tensors));
+ src += sizeof(n_tensors);
+ if (input.size() < 2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
return false;
}
- const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
- LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
+ const rpc_tensor * tensors = (const rpc_tensor *)src;
+ LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
@@ -1380,7 +1453,7 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph
return false;
}
}
- ggml_status status = ggml_backend_graph_compute(backend, graph);
+ ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
return true;
}
@@ -1391,9 +1464,9 @@ rpc_server::~rpc_server() {
}
}
-static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
- sockfd_t sockfd, size_t free_mem, size_t total_mem) {
- rpc_server server(backend, cache_dir);
+static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
+ sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
+ rpc_server server(backends, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
@@ -1425,13 +1498,26 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
// HELLO command is handled above
return;
}
+ case RPC_CMD_DEVICE_COUNT: {
+ if (!recv_msg(sockfd, nullptr, 0)) {
+ return;
+ }
+ rpc_msg_device_count_rsp response;
+ response.device_count = backends.size();
+ if (!send_msg(sockfd, &response, sizeof(response))) {
+ return;
+ }
+ break;
+ }
case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_alloc_buffer_rsp response;
- server.alloc_buffer(request, response);
+ if (!server.alloc_buffer(request, response)) {
+ return;
+ }
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1452,22 +1538,28 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
case RPC_CMD_GET_ALIGNMENT: {
- if (!recv_msg(sockfd, nullptr, 0)) {
+ rpc_msg_get_alignment_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_get_alignment_rsp response;
- server.get_alignment(response);
+ if (!server.get_alignment(request, response)) {
+ return;
+ }
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_GET_MAX_SIZE: {
- if (!recv_msg(sockfd, nullptr, 0)) {
+ rpc_msg_get_max_size_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_get_max_size_rsp response;
- server.get_max_size(response);
+ if (!server.get_max_size(request, response)) {
+ return;
+ }
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1593,12 +1685,19 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
case RPC_CMD_GET_DEVICE_MEMORY: {
- if (!recv_msg(sockfd, nullptr, 0)) {
+ rpc_msg_get_device_memory_req request;
+ if (!recv_msg(sockfd, &request, sizeof(request))) {
+ return;
+ }
+ auto dev_id = request.device;
+ if (dev_id >= backends.size()) {
return;
}
rpc_msg_get_device_memory_rsp response;
- response.free_mem = free_mem;
- response.total_mem = total_mem;
+ response.free_mem = free_mem[dev_id];
+ response.total_mem = total_mem[dev_id];
+ LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
+ response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
@@ -1612,16 +1711,41 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
}
}
-void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
- const char * cache_dir,
- size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
+ size_t n_threads, size_t n_devices,
+ ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
+ if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
+ fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
+ return;
+ }
+ std::vector<ggml_backend_t> backends;
+ std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
+ std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
RPC_PROTO_PATCH_VERSION);
printf(" endpoint : %s\n", endpoint);
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
- printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
+ printf("Devices:\n");
+ for (size_t i = 0; i < n_devices; i++) {
+ auto dev = devices[i];
+ printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+ total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
+ auto backend = ggml_backend_dev_init(dev, nullptr);
+ if (!backend) {
+ fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
+ return;
+ }
+ backends.push_back(backend);
+ ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+ if (reg) {
+ auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+ if (ggml_backend_set_n_threads_fn) {
+ ggml_backend_set_n_threads_fn(backend, n_threads);
+ }
+ }
+ }
std::string host;
int port;
@@ -1649,22 +1773,27 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
fprintf(stderr, "Failed to accept client connection\n");
return;
}
- printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
+ printf("Accepted client connection\n");
fflush(stdout);
- rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
+ rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
printf("Client connection closed\n");
fflush(stdout);
}
#ifdef _WIN32
WSACleanup();
#endif
+ for (auto backend : backends) {
+ ggml_backend_free(backend);
+ }
}
// device interface
struct ggml_backend_rpc_device_context {
std::string endpoint;
+ uint32_t device;
std::string name;
+ std::string description;
};
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
@@ -1676,15 +1805,13 @@ static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
- return ctx->name.c_str();
+ return ctx->description.c_str();
}
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
- ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
-
- GGML_UNUSED(dev);
+ ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), ctx->device, free, total);
}
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
@@ -1710,7 +1837,7 @@ static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggm
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
- return ggml_backend_rpc_init(ctx->endpoint.c_str());
+ return ggml_backend_rpc_init(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(params);
}
@@ -1718,7 +1845,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
- return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str(), ctx->device);
GGML_UNUSED(dev);
}
@@ -1736,7 +1863,7 @@ static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_b
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
- return buft_ctx->endpoint == dev_ctx->endpoint;
+ return buft_ctx->endpoint == dev_ctx->endpoint && buft_ctx->device == dev_ctx->device;
}
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
@@ -1759,28 +1886,34 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
// backend reg interface
-static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
- return "RPC";
+struct ggml_backend_rpc_reg_context {
+ std::string name;
+ std::vector<ggml_backend_dev_t> devices;
+};
- GGML_UNUSED(reg);
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+ return ctx ? ctx->name.c_str() : "RPC";
}
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
- return 0;
-
- GGML_UNUSED(reg);
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+ return ctx ? ctx->devices.size() : 0;
}
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
- GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
-
- GGML_UNUSED(reg);
- GGML_UNUSED(index);
+ ggml_backend_rpc_reg_context * ctx = (ggml_backend_rpc_reg_context *)reg->context;
+ if (ctx == nullptr) {
+ GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_rpc_add_server instead");
+ } else {
+ GGML_ASSERT(index < ctx->devices.size());
+ return ctx->devices[index];
+ }
}
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
- if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
- return (void *)ggml_backend_rpc_add_device;
+ if (std::strcmp(name, "ggml_backend_rpc_add_server") == 0) {
+ return (void *)ggml_backend_rpc_add_server;
}
if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
return (void *)ggml_backend_rpc_start_server;
@@ -1807,30 +1940,61 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
return &ggml_backend_rpc_reg;
}
-ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
- static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
-
- static std::mutex mutex;
- std::lock_guard<std::mutex> lock(mutex);
-
- if (dev_map.find(endpoint) != dev_map.end()) {
- return dev_map[endpoint];
- }
-
- ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
- /* .endpoint = */ endpoint,
- /* .name = */ "RPC[" + std::string(endpoint) + "]",
- };
-
- ggml_backend_dev_t dev = new ggml_backend_device {
- /* .iface = */ ggml_backend_rpc_device_i,
- /* .reg = */ ggml_backend_rpc_reg(),
- /* .context = */ ctx,
- };
-
- dev_map[endpoint] = dev;
-
- return dev;
+static uint32_t ggml_backend_rpc_get_device_count(const char * endpoint) {
+ auto sock = get_socket(endpoint);
+ rpc_msg_device_count_rsp response;
+ bool status = send_rpc_cmd(sock, RPC_CMD_DEVICE_COUNT, nullptr, 0, &response, sizeof(response));
+ RPC_STATUS_ASSERT(status);
+ return response.device_count;
}
+static const ggml_backend_reg_i ggml_backend_rpc_reg_interface = {
+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) {
+ static std::unordered_map<std::string, ggml_backend_reg_t> reg_map;
+ static std::mutex mutex;
+ static uint32_t dev_id = 0;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (reg_map.find(endpoint) != reg_map.end()) {
+ return reg_map[endpoint];
+ }
+ uint32_t dev_count = ggml_backend_rpc_get_device_count(endpoint);
+ if (dev_count == 0) {
+ return nullptr;
+ }
+ ggml_backend_rpc_reg_context * ctx = new ggml_backend_rpc_reg_context;
+ ctx->name = "RPC[" + std::string(endpoint) + "]";
+ for (uint32_t ind = 0; ind < dev_count; ind++) {
+ std::string dev_name = "RPC" + std::to_string(dev_id);
+ std::string dev_desc = std::string(endpoint);
+ ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context {
+ /* .endpoint = */ endpoint,
+ /* .device = */ ind,
+ /* .name = */ dev_name,
+ /* .description = */ dev_desc
+ };
+
+ ggml_backend_dev_t dev = new ggml_backend_device {
+ /* .iface = */ ggml_backend_rpc_device_i,
+ /* .reg = */ ggml_backend_rpc_reg(),
+ /* .context = */ dev_ctx,
+ };
+ ctx->devices.push_back(dev);
+ dev_id++;
+ }
+ ggml_backend_reg_t reg = new ggml_backend_reg {
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
+ /* .iface = */ ggml_backend_rpc_reg_interface,
+ /* .context = */ ctx
+ };
+ reg_map[endpoint] = reg;
+ return reg;
+}
+
+
GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)
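To make the updated RPC_CMD_GRAPH_COMPUTE wire format above concrete: the request now starts with the target device id, followed by the node count, the node ids, the tensor count and the serialized tensor descriptors. A minimal client-side sketch that builds just the fixed-layout prefix (make_graph_compute_prefix is an illustrative name; the rpc_tensor payload is omitted):

#include <cstdint>
#include <cstring>
#include <vector>

// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * 8 bytes) | n_tensors (4 bytes) | tensors ... |
static std::vector<uint8_t> make_graph_compute_prefix(uint32_t device, const std::vector<uint64_t> & node_ids) {
    const uint32_t n_nodes = (uint32_t) node_ids.size();
    std::vector<uint8_t> out(2*sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t));
    uint8_t * dst = out.data();
    memcpy(dst, &device,  sizeof(device));  dst += sizeof(device);
    memcpy(dst, &n_nodes, sizeof(n_nodes)); dst += sizeof(n_nodes);
    memcpy(dst, node_ids.data(), n_nodes*sizeof(uint64_t)); dst += n_nodes*sizeof(uint64_t);
    const uint32_t n_tensors = 0; // the real message appends the rpc_tensor array after this count
    memcpy(dst, &n_tensors, sizeof(n_tensors));
    return out;
}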
diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt
index b97e7bf995..83a83887b5 100644
--- a/ggml/src/ggml-vulkan/CMakeLists.txt
+++ b/ggml/src/ggml-vulkan/CMakeLists.txt
@@ -1,5 +1,6 @@
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
+cmake_policy(SET CMP0116 NEW)
find_package(Vulkan COMPONENTS glslc REQUIRED)
@@ -54,25 +55,25 @@ if (Vulkan_FOUND)
# Test all shader extensions
test_shader_extension_support(
"GL_KHR_cooperative_matrix"
- "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
+ "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp"
"GGML_VULKAN_COOPMAT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_NV_cooperative_matrix2"
- "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
+ "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp"
"GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_integer_dot_product"
- "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
+ "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp"
"GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT"
)
test_shader_extension_support(
"GL_EXT_bfloat16"
- "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
+ "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp"
"GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT"
)
@@ -160,7 +161,6 @@ if (Vulkan_FOUND)
set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$<CONFIG>")
set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}")
set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp")
- set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp")
set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders")
set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv")
@@ -176,24 +176,35 @@ if (Vulkan_FOUND)
add_custom_command(
OUTPUT ${_ggml_vk_header}
- ${_ggml_vk_source}
-
COMMAND ${_ggml_vk_genshaders_cmd}
- --glslc ${Vulkan_GLSLC_EXECUTABLE}
- --input-dir ${_ggml_vk_input_dir}
--output-dir ${_ggml_vk_output_dir}
--target-hpp ${_ggml_vk_header}
- --target-cpp ${_ggml_vk_source}
- --no-clean
-
- DEPENDS ${_ggml_vk_shader_files}
- ${_ggml_vk_shaders_gen_sources}
+ DEPENDS ${_ggml_vk_shaders_gen_sources}
vulkan-shaders-gen
-
- COMMENT "Generate vulkan shaders"
+ COMMENT "Generate vulkan shaders header"
)
+ target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header})
- target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})
+ foreach (file_full ${_ggml_vk_shader_files})
+ get_filename_component(file ${file_full} NAME)
+ set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp")
+
+ add_custom_command(
+ OUTPUT ${_ggml_vk_target_cpp}
+ DEPFILE ${_ggml_vk_target_cpp}.d
+ COMMAND ${_ggml_vk_genshaders_cmd}
+ --glslc ${Vulkan_GLSLC_EXECUTABLE}
+ --source ${file_full}
+ --output-dir ${_ggml_vk_output_dir}
+ --target-hpp ${_ggml_vk_header}
+ --target-cpp ${_ggml_vk_target_cpp}
+ DEPENDS ${file_full}
+ ${_ggml_vk_shaders_gen_sources}
+ vulkan-shaders-gen
+ COMMENT "Generate vulkan shaders for ${file}"
+ )
+ target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp})
+ endforeach()
else()
message(WARNING "Vulkan not found")
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 003a901067..3cd89c7116 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -393,6 +393,7 @@ struct vk_device_struct {
vk::PhysicalDeviceProperties properties;
std::string name;
uint64_t max_memory_allocation_size;
+ uint64_t max_buffer_size;
uint64_t suballocation_block_size;
bool fp16;
bool bf16;
@@ -1563,6 +1564,12 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
static void ggml_backend_vk_free(ggml_backend_t backend);
+static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
+ const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
+ VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
+ return range;
+}
+
// Wait for ctx->fence to be signaled.
static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
@@ -2012,8 +2019,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
- if (size > device->max_memory_allocation_size) {
- throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
+ if (size > device->max_buffer_size) {
+ throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
}
vk_buffer buf = std::make_shared<vk_buffer_struct>();
@@ -2159,8 +2166,8 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
buf.reset();
}
-static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) {
- return { buf, 0, VK_WHOLE_SIZE };
+static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
+ return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
}
static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
@@ -2614,8 +2621,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
- // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
- GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
};
@@ -3855,17 +3860,27 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
- device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
+ device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
} else if (maintenance4_support) {
device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
} else {
device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
}
+ const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
+
+ if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
+ device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
+ } else if (maintenance4_support) {
+ device->max_buffer_size = props4.maxBufferSize;
+ } else {
+ device->max_buffer_size = device->max_memory_allocation_size;
+ }
+
const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
- device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
+ device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
} else {
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
device->suballocation_block_size = 1024*1024*1024;
@@ -6150,9 +6165,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
- (split_k > 1 && split_k_size > ctx->device->max_memory_allocation_size)) {
+ (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
+ (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
+ (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6227,7 +6242,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
}
if (x_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1});
@@ -6239,7 +6254,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6250,7 +6265,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6272,14 +6287,11 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
}
- // No bounds checking is needed for dst. This is basically VK_WHOLE_SIZE but clamped to maxStorageBufferRange.
- VkDeviceSize d_range = std::min(VkDeviceSize{d_D->size - d_buf_offset}, VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
-
// compute
ggml_vk_matmul(
ctx, subctx, pipeline,
{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
- { d_D, d_buf_offset, d_range }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
+ ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
ne01, ne11, ne10,
ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
@@ -6446,8 +6458,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
}
if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+ (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
+ (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6512,7 +6524,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
}
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
@@ -6521,7 +6533,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6532,7 +6544,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13, true);
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -6931,8 +6943,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+ (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
+ (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -6999,7 +7011,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
}
if (x_non_contig) {
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
} else if (qx_needs_dequant) {
const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
@@ -7012,7 +7024,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -7145,8 +7157,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
if (
- (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
- (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size)) {
+ (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
+ (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
GGML_ABORT("Requested preallocation size is too large");
}
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
@@ -7212,7 +7224,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (x_non_contig) {
GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, { d_Qx, qx_buf_offset, VK_WHOLE_SIZE }, { d_X, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
}
if (y_non_contig) {
GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
@@ -7221,7 +7233,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
@@ -7457,8 +7469,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
aligned = false;
}
- // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
- GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
@@ -7498,7 +7508,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
- if (split_k_size > ctx->device->max_memory_allocation_size) {
+ if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
GGML_ABORT("Requested preallocation size is too large");
}
if (ctx->prealloc_size_split_k < split_k_size) {
@@ -7620,12 +7630,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
- vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+ ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
+ ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
+ ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
+ ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
+ ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
+ ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
@@ -7637,21 +7647,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
- vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
- vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+ ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
+ ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
+ ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
ctx->prealloc_split_k_need_sync = true;
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
- vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_S, s_buf_offset, VK_WHOLE_SIZE},
- vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+ ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
+ ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
+ ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
+ ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
+ ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
+ ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
pc, { workgroups_x, workgroups_y, workgroups_z });
}
@@ -8360,18 +8370,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
- uint64_t x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0;
- uint64_t y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 : 0;
- uint64_t z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 : 0;
- uint64_t d_sz = ggml_type_size(dst->type) * ned;
-
vk_buffer d_D = dst_buf_ctx->dev_buffer;
- // Workaround for tiny tensor inputs on ROPE
- if (op == GGML_OP_ROPE && use_src1 && y_sz > d_D->size) {
- y_sz = VK_WHOLE_SIZE;
- }
-
GGML_ASSERT(d_D != nullptr);
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
if(!src0_uma) {
@@ -8396,26 +8396,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
- if (op_supports_incontiguous) {
- x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
- y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
- z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
- d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
-
- if (x_buf_offset + x_sz >= d_X->size) {
- x_sz = VK_WHOLE_SIZE;
- }
- if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
- y_sz = VK_WHOLE_SIZE;
- }
- if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
- z_sz = VK_WHOLE_SIZE;
- }
- if (d_buf_offset + d_sz >= d_D->size) {
- d_sz = VK_WHOLE_SIZE;
- }
- }
-
std::array<uint32_t, 3> elements;
// Single call if dimension 2 is contiguous
@@ -8606,19 +8586,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
}
- if (!op_supports_incontiguous) {
- if (x_sz != VK_WHOLE_SIZE) {
- x_sz *= ne02 * ne03;
+ uint64_t x_sz, y_sz, z_sz, d_sz;
+
+ if (op_supports_incontiguous) {
+ x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
+ y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
+ z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
+ d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
+
+ if (x_buf_offset + x_sz >= d_X->size) {
+ x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
}
- if (use_src1 && y_sz != VK_WHOLE_SIZE) {
- y_sz *= ne12 * ne13;
+ if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
+ y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
}
- if (use_src2 && z_sz != VK_WHOLE_SIZE) {
- z_sz *= ne22 * ne23;
+ if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
+ z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
- if (d_sz != VK_WHOLE_SIZE) {
- d_sz *= ned2 * ned3;
+ if (d_buf_offset + d_sz >= d_D->size) {
+ d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
+ } else {
+ x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
+ y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
+ z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
+ d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}
if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
@@ -8628,7 +8620,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
vk_subbuffer{ d_D, d_buf_offset, d_sz },
- vk_subbuffer{ d_A, a_buf_offset, VK_WHOLE_SIZE },
+ ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
}, pc, elements);
} else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
@@ -8821,18 +8813,18 @@ static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx,
static_assert(MAX_PARAMETER_COUNT == 12);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
- vk_subbuffer{ buf[0], offset[0], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[1], offset[1], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[2], offset[2], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[3], offset[3], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[4], offset[4], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[5], offset[5], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[6], offset[6], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[7], offset[7], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[8], offset[8], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[9], offset[9], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[10], offset[10], VK_WHOLE_SIZE },
- vk_subbuffer{ buf[11], offset[11], VK_WHOLE_SIZE },
+ ggml_vk_subbuffer(ctx, buf[0], offset[0]),
+ ggml_vk_subbuffer(ctx, buf[1], offset[1]),
+ ggml_vk_subbuffer(ctx, buf[2], offset[2]),
+ ggml_vk_subbuffer(ctx, buf[3], offset[3]),
+ ggml_vk_subbuffer(ctx, buf[4], offset[4]),
+ ggml_vk_subbuffer(ctx, buf[5], offset[5]),
+ ggml_vk_subbuffer(ctx, buf[6], offset[6]),
+ ggml_vk_subbuffer(ctx, buf[7], offset[7]),
+ ggml_vk_subbuffer(ctx, buf[8], offset[8]),
+ ggml_vk_subbuffer(ctx, buf[9], offset[9]),
+ ggml_vk_subbuffer(ctx, buf[10], offset[10]),
+ ggml_vk_subbuffer(ctx, buf[11], offset[11]),
}, pc, elements);
}
@@ -10006,7 +9998,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
ggml_vk_ctx_begin(ctx->device, subctx);
for (size_t i = 0; i < num_it; i++) {
ggml_vk_matmul(
- ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
+ ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
m, n, k,
k, k, m, k*m, k*n, m*n,
split_k, batch, batch, batch, 1, 1, n
@@ -10317,7 +10309,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
//
// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
// ggml_vk_ctx_begin(ctx->device, subctx);
-// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
+// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
// ggml_vk_ctx_end(subctx);
//
// auto begin = std::chrono::high_resolution_clock::now();
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
index d896f1ef0b..5084a70ed4 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
index 00cf2dd62f..3bcfe6908e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/add.comp
@@ -6,8 +6,8 @@
#extension GL_KHR_shader_subgroup_basic : enable
#endif
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
const uint num_threads = 256;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp
index 3ae8f0116c..495249d5f6 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_control_flow_attributes : require
-#include "types.comp"
+#include "types.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
index a1d4c240dd..7c12877671 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
index dc53a401e0..c81b84452e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
-#include "types.comp"
+#include "types.glsl"
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp
index 1e5cb8dae4..653431895e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp
index 9ee2f1fae2..e404698382 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
index 6567a8c54c..ca1a3ac25b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
#extension GL_EXT_control_flow_attributes : require
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
index 938c74da50..70a301488e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
@@ -1,6 +1,6 @@
#version 450
-#include "types.comp"
+#include "types.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
index 44a64ddc80..0367e80bbf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
@@ -11,7 +11,7 @@
# extension GL_KHR_shader_subgroup_shuffle : enable
#endif
-#include "types.comp"
+#include "types.glsl"
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
index b17b4e83ee..5217e18bdd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp
@@ -1,6 +1,6 @@
#version 450
-#include "types.comp"
+#include "types.glsl"
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin]
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin]
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
index f476a2e3dd..9f8bfd3c18 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
index 978d430030..06df509525 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
@@ -1,8 +1,8 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
-#include "dequant_funcs.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+#include "dequant_funcs.glsl"
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
// 16 invocations needed for init_iq_shmem
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
index bc2e1f2df3..b8c40eec10 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
@@ -1,7 +1,7 @@
#version 450
-#include "rte.comp"
-#include "types.comp"
+#include "rte.glsl"
+#include "types.glsl"
#if defined(SET_ROWS) && QUANT_K == 1
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
@@ -14,7 +14,7 @@ const uint BLOCK_SIZE = 32;
layout (binding = 0) readonly buffer S {float data_s[];};
#if defined(SET_ROWS)
-#include "generic_binary_head.comp"
+#include "generic_binary_head.glsl"
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
@@ -25,7 +25,7 @@ layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
#endif
#else
-#include "generic_unary_head.comp"
+#include "generic_unary_head.glsl"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
#endif
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp
index 0b8d02f58f..db6865db98 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
index d9345497c7..e75df66756 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp
@@ -2,8 +2,8 @@
#extension GL_EXT_control_flow_attributes : enable
-#include "types.comp"
-#include "generic_head.comp"
+#include "types.glsl"
+#include "generic_head.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp
index a4d3fca556..765afffa80 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
similarity index 99%
rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
index 73fef4fa65..0d98f5a9d6 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
-#include "types.comp"
+#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
similarity index 99%
rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
index 706540fd85..6a5bb4574d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
@@ -1,5 +1,5 @@
-#include "types.comp"
+#include "types.glsl"
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl
similarity index 91%
rename from ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl
index 8d806435b7..addceafade 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl
@@ -10,4 +10,4 @@ layout (push_constant) uniform parameter
uint nel;
} p;
-#include "types.comp"
+#include "types.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
index b604c1881a..637c95fa35 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
index fd1e4e30d2..d1cbc5e9d0 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp
index 127c7b6424..78490162cd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp
index a08331c40d..9b8ce0a7f8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp
index 0ae9acd02a..aacf07d0f8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp
index e4f42be94c..f2c20b1d2c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp
index 19c7fdeefc..671c1f4a0d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp
index 46d9ad15eb..8f7833eab2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp
index f930852a48..a313699775 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
index ee496e9d56..ffba5a77dd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
index d4e4e6bae6..58dc2e5dfd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp
index 3661f771c7..0c90be8b4e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp
index 4081853272..b92b292135 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp
index 2f27eee686..6b63cbe583 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
index 1370db3654..8b7be557e9 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp
index b20b805292..f1b0bac872 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp
index dc59fe3b77..c495b31f17 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
index 3f3b839e11..6bc04670fc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp
index 9cf34256e8..c8d6fcb49f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp
index bd1344a88d..10844ddf78 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp
@@ -1,6 +1,6 @@
#version 450
-#include "dequant_head.comp"
+#include "dequant_head.glsl"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
index 26d8bc22ad..9cef8a8ec3 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp
@@ -10,7 +10,7 @@ layout (push_constant) uniform parameter
uint n_past;
} p;
-#include "types.comp"
+#include "types.glsl"
layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp
index 9fb69c6c15..572472f8a9 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/div.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/div.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
const uint num_threads = 256;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
index a3941372a7..b69d4ddb09 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp
@@ -1,8 +1,8 @@
#version 450
-#include "rte.comp"
-#include "generic_head.comp"
-#include "types.comp"
+#include "rte.glsl"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/bfloat16.comp
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat.comp
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat2_support.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2.comp
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/integer_dot.comp
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 43b906e5ed..62acbf107a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -8,8 +8,8 @@
#extension GL_KHR_shader_subgroup_shuffle : enable
-#include "types.comp"
-#include "flash_attn_base.comp"
+#include "types.glsl"
+#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
@@ -153,12 +153,13 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
- if (!KV_bounds_check || j * Bc + c < KV) {
+ if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index ddb1246e0b..2066a05b34 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -10,8 +10,8 @@
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
-#include "types.comp"
-#include "flash_attn_base.comp"
+#include "types.glsl"
+#include "flash_attn_base.glsl"
const uint32_t HSK_per_thread = HSK / D_split;
const uint32_t HSV_per_thread = HSV / D_split;
@@ -201,11 +201,13 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
- if (!KV_bounds_check || j * Bc + c < KV) {
+ if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
@@ -356,8 +358,8 @@ void main() {
}
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
- [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
- float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+ [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+ float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
float ms = 1.0f;
float vs = 1.0f;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index ab647e9bc8..910da1ab0c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -16,9 +16,9 @@
#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_EXT_null_initializer : enable
-#include "types.comp"
-#include "dequant_funcs_cm2.comp"
-#include "flash_attn_base.comp"
+#include "types.glsl"
+#include "dequant_funcs_cm2.glsl"
+#include "flash_attn_base.glsl"
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
@@ -154,15 +154,31 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
- tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+ if (nem1_bounds_check) {
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ } else {
+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
+ // Don't clamp against nem1 when GQA is enabled
+ uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
+
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ }
}
// Clear padding elements to -inf, so they don't contribute to rowmax
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
index f4268ed24f..e017b50368 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp
@@ -1,6 +1,6 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -10,4 +10,4 @@ float op(float a, float b) {
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
index cbd4cb36bf..759a1848fa 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp
@@ -1,6 +1,6 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation
// ref: https://www.johndcook.com/blog/python_erf/
@@ -24,4 +24,4 @@ float op(float a, float b) {
return 0.5f * a * (1.0f + erf_approx) * b;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
index 3a2a6897bf..c4032ab21d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp
@@ -1,6 +1,6 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
const float GELU_QUICK_COEF = -1.702f;
@@ -8,4 +8,4 @@ float op(float a, float b) {
return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp
index 4cc7a68ca1..a95c2525c8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp
index 5fd5a5e703..58375aba09 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp
index e6e6fcfd20..bfdfe2182d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl
similarity index 97%
rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl
index 750e785753..99595fc688 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl
@@ -1,8 +1,8 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
-#include "rte.comp"
-#include "utils.comp"
+#include "rte.glsl"
+#include "utils.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
index 7ef75cd7a4..76d83041ce 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
index 339f905fc7..9dba437edb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
@@ -2,9 +2,9 @@
#extension GL_EXT_control_flow_attributes : enable
-#include "types.comp"
-#include "generic_binary_head.comp"
-#include "dequant_funcs.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
+#include "dequant_funcs.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl
similarity index 95%
rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl
index 51d70869d9..2168989340 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl
@@ -1,6 +1,6 @@
#extension GL_EXT_shader_16bit_storage : require
-#include "rte.comp"
+#include "rte.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/glu_main.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp
index b6a0d56454..bdf97dbb5d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp
index 1da252cc66..b4dbdf3141 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp
index 3afc588274..1ec315915e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
index f0f19a019c..1827d647a2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -3,9 +3,8 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
-#include "rte.comp"
-
-#include "types.comp"
+#include "rte.glsl"
+#include "types.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
index 9faa636ac2..4bf8b4ca04 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
@@ -4,9 +4,8 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "rte.comp"
-
-#include "types.comp"
+#include "rte.glsl"
+#include "types.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
index deba8c3985..83ef2f8795 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp
index d90a99aea5..b281e855cb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
index 43de19df8e..02ef1eace1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
const uint num_threads = 256;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
index bb429dd594..9a03925cfd 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
similarity index 99%
rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
index f761391eae..450dee0408 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
@@ -11,7 +11,7 @@
#define EXPERT_COUNT 8
#endif
-#include "types.comp"
+#include "types.glsl"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -32,7 +32,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
-#include "dequant_funcs.comp"
+#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
index e4acbd4f96..4cb292380c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
index 309da0991a..0b74b33212 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
index 8d01536fa6..e424af12c5 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
index c496043241..0cd906dbbf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
index 94d4b92e1e..71bd72d17e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
index f021e40476..a4b9ab1f94 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
index 3fe9dc3a41..40849c691f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 423ceb8a3d..03ed25d3bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
index e91724a28d..528f224d86 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -1,7 +1,7 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index f9cde06488..21d07d2e50 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 6c84ef3cde..9e46c89a11 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
index d53d9ee0a2..d7a7f6426e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
index 8fb314fa0a..64293f6eca 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
@@ -6,13 +6,13 @@
#define MMQ
#define B_TYPE block_q8_1_x4
-#include "mul_mat_vec_base.comp"
+#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8
-#include "mul_mmq_funcs.comp"
+#include "mul_mmq_funcs.glsl"
uint a_offset, b_offset, d_offset;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 3cb24412d5..85400ac5fc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
@@ -28,7 +28,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
-#include "types.comp"
+#include "types.glsl"
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
@@ -195,7 +195,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
-#include "mul_mm_funcs.comp"
+#include "mul_mm_funcs.glsl"
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index 0e3065e014..2e04baa44e 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
@@ -18,8 +18,8 @@
#extension GL_EXT_bfloat16 : enable
#endif
-#include "types.comp"
-#include "utils.comp"
+#include "types.glsl"
+#include "utils.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@@ -71,7 +71,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA
-#include "dequant_funcs_cm2.comp"
+#include "dequant_funcs_cm2.glsl"
#else
#define DECODEFUNCA
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index f36add62a9..b5d761c0ba 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -20,7 +20,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
-#include "types.comp"
+#include "types.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
@@ -110,7 +110,7 @@ shared u16vec2 row_ids[4096];
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
-#include "mul_mmq_funcs.comp"
+#include "mul_mmq_funcs.glsl"
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
similarity index 99%
rename from ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index cdfb230f4e..fe71eb131c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -2,7 +2,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-#include "types.comp"
+#include "types.glsl"
// Each iqs value maps to a 32-bit integer
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
index 854a2ad818..1e8f694a72 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp
@@ -8,9 +8,9 @@
#extension GL_KHR_shader_subgroup_basic : enable
#endif
-#include "rte.comp"
-#include "types.comp"
-#include "utils.comp"
+#include "rte.glsl"
+#include "types.glsl"
+#include "utils.glsl"
layout (push_constant) uniform parameter2
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp
index 6627a50bd9..cc3ea0b760 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
index e0214fe764..1f05f922cc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
index 6426dedee5..1251f9cc64 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
@@ -1,6 +1,6 @@
#version 450
-#include "generic_head.comp"
+#include "generic_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp
index 0d81220c71..f3c8176872 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp
@@ -1,6 +1,6 @@
#version 450
-#include "types.comp"
+#include "types.glsl"
layout (push_constant) uniform parameter
{
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp
index b6124411a0..d9d7166e36 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp
@@ -1,6 +1,6 @@
#version 450
-#include "types.comp"
+#include "types.glsl"
#extension GL_EXT_shader_16bit_storage : require
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
index 145c9fbdc9..0f3c6ca871 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
@@ -17,7 +17,7 @@ layout (push_constant) uniform parameter
uint ne;
} p;
-#include "types.comp"
+#include "types.glsl"
layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
index 0073d8f766..86be2669a1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp
@@ -1,9 +1,9 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
float op(float a, float b) {
return max(a, 0.0f) * b;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp
index 4f806270c7..5725cef236 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp
index 1568b141de..8f4b9a8684 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
index d86279934f..87df782944 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
index 41197e9301..d5b211ffaa 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_binary_head.comp"
-#include "types.comp"
+#include "generic_binary_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
index 76009f3df6..87707fc149 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
index ba4677c293..4618b2c7e8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_binary_head.comp"
-#include "types.comp"
+#include "generic_binary_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp
index b9abe8dedc..68fbd0c7be 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
similarity index 97%
rename from ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
index 00e203e73b..50fc1f1e2d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl
@@ -1,8 +1,8 @@
-#include "types.comp"
+#include "types.glsl"
#extension GL_EXT_shader_16bit_storage : require
-#include "rte.comp"
+#include "rte.glsl"
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index 5808710ccf..111286b498 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -1,6 +1,6 @@
#version 450
-#include "rope_head.comp"
+#include "rope_head.glsl"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 366a7b1c47..06e095bef9 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -1,6 +1,6 @@
#version 450
-#include "rope_head.comp"
+#include "rope_head.glsl"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 9643bca96a..6ba9575409 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -1,6 +1,6 @@
#version 450
-#include "rope_head.comp"
+#include "rope_head.glsl"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
index cedacc4d14..d37d1c1043 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
@@ -1,6 +1,6 @@
#version 450
-#include "rope_head.comp"
+#include "rope_head.glsl"
void main() {
const uint i0 = 2*gl_GlobalInvocationID.y;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/rte.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
index f10b0a02b5..35ec726a01 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
const uint num_threads = 128;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp
index 5c9e5c3503..32298d43c6 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp
index 4d36f88e08..7d1cc6f45a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
index f9afa9b13c..e5d949ff18 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp
index d7c15a1695..61f17b2f00 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
index 5f20a1ee7d..dca0d896bc 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
@@ -23,7 +23,7 @@ layout (push_constant) uniform parameter
uint has_sinks;
} p;
-#include "types.comp"
+#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
index 144ea58e6f..d873332eeb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
@@ -2,8 +2,8 @@
#extension GL_EXT_control_flow_attributes : enable
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp
index 4bc697b9b9..70daad6c5d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp
index ef43598baf..4eb56afcb1 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp
@@ -1,7 +1,7 @@
#version 450
-#include "types.comp"
-#include "generic_unary_head.comp"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
index 72353cc329..bc924b520a 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp
@@ -2,8 +2,8 @@
#extension GL_EXT_shader_16bit_storage : require
-#include "types.comp"
-#include "generic_binary_head.comp"
+#include "types.glsl"
+#include "generic_binary_head.glsl"
const uint num_threads = 256;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp
index 759204afaf..bc22aa7bd7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp
@@ -1,6 +1,6 @@
#version 450
-#include "types.comp"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
index a28e7c6cc8..4fee433a12 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp
@@ -1,9 +1,9 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
float op(float a, float b) {
return a / (1.0f + exp(-a)) * b;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp
index 970750eec0..bda9dea21c 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp
@@ -1,6 +1,6 @@
#version 450
-#include "glu_head.comp"
+#include "glu_head.glsl"
float op(float a, float b) {
float xi = min(a, p.limit);
@@ -11,4 +11,4 @@ float op(float a, float b) {
return out_glu;
}
-#include "glu_main.comp"
+#include "glu_main.glsl"
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp
index 8a6f868f58..7b5eb413bf 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp
@@ -1,7 +1,7 @@
#version 450
-#include "generic_head.comp"
-#include "types.comp"
+#include "generic_head.glsl"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp
index ce8e09442d..1605565457 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter
uint max_period;
} p;
-#include "types.comp"
+#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 256
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/types.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/types.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
index 74771def0f..154a2172d8 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp
@@ -9,7 +9,7 @@ layout (push_constant) uniform parameter
float sf0; float sf1; float sf2; float sf3;
} p;
-#include "types.comp"
+#include "types.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/utils.comp b/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl
similarity index 100%
rename from ggml/src/ggml-vulkan/vulkan-shaders/utils.comp
rename to ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 84bb9df9a0..f0cc24ff31 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -1,5 +1,3 @@
-
-
#include
#include
#include
@@ -22,6 +20,7 @@
#include
#ifdef _WIN32
+ #define NOMINMAX
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#else
@@ -34,13 +33,13 @@
std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;
+std::locale c_locale("C");
std::string GLSLC = "glslc";
-std::string input_dir = "vulkan-shaders";
+std::string input_filepath = "";
std::string output_dir = "/tmp";
-std::string target_hpp = "ggml-vulkan-shaders.hpp";
-std::string target_cpp = "ggml-vulkan-shaders.cpp";
-bool no_clean = false;
+std::string target_hpp = "";
+std::string target_cpp = "";
const std::vector<std::string> type_names = {
"f32",
@@ -75,6 +74,7 @@ enum MatMulIdType {
};
namespace {
+
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
HANDLE stdout_read, stdout_write;
@@ -232,16 +232,87 @@ std::string basename(const std::string &path) {
return path.substr(path.find_last_of("/\\") + 1);
}
+std::stringstream make_generic_stringstream() {
+ std::stringstream ss;
+ ss.imbue(c_locale);
+ return ss;
+}
+
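
Note: the generator now builds all generated text through streams imbued with the "C" locale (c_locale above), so numeric output in the emitted sources cannot pick up digit grouping from the build host's locale. A minimal standalone sketch of the effect; the program below is illustrative and not part of the patch:

    #include <iostream>
    #include <locale>
    #include <sstream>

    int main() {
        std::stringstream ss;
        ss.imbue(std::locale("C"));     // plain digits, no thousands separators
        ss << 1234567;                  // a grouping locale could otherwise print "1,234,567"
        std::cout << ss.str() << "\n";  // always prints 1234567
        return 0;
    }
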
+std::string read_binary_file(const std::string& path, bool may_not_exist = false) {
+ FILE* f = fopen(path.c_str(), "rb");
+ if (!f) {
+ if (!may_not_exist) {
+ std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n";
+ }
+ return {};
+ }
+
+ fseek(f, 0, SEEK_END);
+ size_t size = ftell(f);
+ fseek(f, 0, SEEK_SET);
+
+ std::string data(size, '\0');
+ size_t read_size = fread(data.data(), 1, size, f);
+ fclose(f);
+ if (read_size != size) {
+ std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n";
+ return {};
+ }
+
+ return data;
+}
+
+void write_binary_file(const std::string& path, const std::string& content) {
+ FILE* f = fopen(path.c_str(), "wb");
+ if (!f) {
+ std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n";
+ return;
+ }
+
+ size_t write_size = fwrite(content.data(), 1, content.size(), f);
+ fclose(f);
+ if (write_size != content.size()) {
+ std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n";
+ return;
+ }
+}
+
+void write_file_if_changed(const std::string& path, const std::string& content) {
+ std::string existing = read_binary_file(path, true);
+ if (existing != content) {
+ write_binary_file(path, content);
+ }
+}
+
+
// variables to track number of compiles in progress
static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
static std::condition_variable compile_count_cond;
+static bool generate_dep_file = true;
-void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
- std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
- std::string out_fname = join_paths(output_dir, name + ".spv");
- std::string in_path = join_paths(input_dir, in_fname);
+void decrement_compile_count(uint32_t * count) {
+ if (count) {
+ std::lock_guard<std::mutex> guard(compile_count_mutex);
+ assert(compile_count > 0);
+ compile_count--;
+ compile_count_cond.notify_all();
+ }
+}
+using compile_count_guard = std::unique_ptr<uint32_t, decltype(&decrement_compile_count)>;
+
+compile_count_guard acquire_compile_slot() {
+ // wait until fewer than N compiles are in progress.
+ // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
+ uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency()));
+ std::unique_lock<std::mutex> guard(compile_count_mutex);
+ compile_count_cond.wait(guard, [N] { return compile_count < N; });
+ compile_count++;
+ return compile_count_guard(&compile_count, &decrement_compile_count);
+}
+
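
Note: acquire_compile_slot and decrement_compile_count form a small counting semaphore with RAII release, so a slot is always returned even when a compile job bails out early. A self-contained sketch of the same pattern with illustrative names (not the patch's identifiers):

    #include <cassert>
    #include <condition_variable>
    #include <cstdint>
    #include <memory>
    #include <mutex>

    static uint32_t jobs_in_flight = 0;
    static std::mutex jobs_mutex;
    static std::condition_variable jobs_cv;

    static void release_slot(uint32_t * count) {
        if (count) {
            std::lock_guard<std::mutex> guard(jobs_mutex);
            assert(*count > 0);
            (*count)--;
            jobs_cv.notify_all();  // wake any thread blocked in acquire_slot
        }
    }

    // unique_ptr with a custom deleter: the slot is released when the guard goes out of scope
    using slot_guard = std::unique_ptr<uint32_t, decltype(&release_slot)>;

    static slot_guard acquire_slot(uint32_t limit) {
        std::unique_lock<std::mutex> guard(jobs_mutex);
        jobs_cv.wait(guard, [limit] { return jobs_in_flight < limit; });
        jobs_in_flight++;
        return slot_guard(&jobs_in_flight, &release_slot);
    }
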
+void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
@@ -249,11 +320,17 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";
#ifdef _WIN32
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
#else
- std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_fname};
+ std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path};
#endif
+ if (dep_file) {
+ cmd.push_back("-MD");
+ cmd.push_back("-MF");
+ cmd.push_back("\"" + target_cpp + ".d\"");
+ }
+
#ifdef GGML_VULKAN_SHADER_DEBUG_INFO
cmd.push_back("-g");
#endif
@@ -281,17 +358,23 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c
return;
}
+ if (dep_file) {
+ // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt
+ std::string dep = read_binary_file(target_cpp + ".d", true);
+ if (!dep.empty()) {
+ size_t pos = dep.find(out_path);
+ if (pos != std::string::npos) {
+ dep.replace(pos, out_path.length(), target_cpp);
+ }
+ write_binary_file(target_cpp + ".d", dep);
+ }
+ }
+
std::lock_guard<std::mutex> guard(lock);
- shader_fnames.push_back(std::make_pair(name, out_fname));
+ shader_fnames.push_back(std::make_pair(name, out_path));
} catch (const std::exception& e) {
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
}
- {
- std::lock_guard guard(compile_count_mutex);
- assert(compile_count > 0);
- compile_count--;
- }
- compile_count_cond.notify_all();
}
std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
@@ -301,18 +384,24 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
std::vector<std::future<void>> compiles;
-void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
- {
- // wait until fewer than N compiles are in progress.
- // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
- uint32_t N = 16;
- std::unique_lock guard(compile_count_mutex);
- while (compile_count >= N) {
- compile_count_cond.wait(guard);
- }
- compile_count++;
+void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
+ name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
+ std::string out_path = join_paths(output_dir, name + ".spv");
+
+ if (input_filepath == "") {
+ // No input source to compile, only generate header for all shaders
+ shader_fnames.push_back(std::pair(name, out_path));
+ return;
+ } else if (basename(input_filepath) != source) {
+ // Only compile shader variants matching the input filename
+ return;
}
- compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
+
+ compile_count_guard slot = acquire_compile_slot();
+ compiles.push_back(std::async(
+ string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
+ // Don't write the same dep file from multiple processes
+ generate_dep_file = false;
}
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
@@ -485,7 +574,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}
void process_shaders() {
- std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};
// matmul
@@ -837,11 +925,11 @@ void process_shaders() {
}
void write_output_files() {
- FILE* hdr = fopen(target_hpp.c_str(), "w");
- FILE* src = fopen(target_cpp.c_str(), "w");
+ std::stringstream hdr = make_generic_stringstream();
+ std::stringstream src = make_generic_stringstream();
- fprintf(hdr, "#include \n\n");
- fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());
+ hdr << "#include \n\n";
+ src << "#include \"" << basename(target_hpp) << "\"\n\n";
std::sort(shader_fnames.begin(), shader_fnames.end());
for (const auto& pair : shader_fnames) {
@@ -853,91 +941,85 @@ void write_output_files() {
const std::string& path = pair.second;
#endif
- FILE* spv = fopen(path.c_str(), "rb");
- if (!spv) {
- std::cerr << "Error opening SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
- continue;
- }
+ hdr << "extern const uint64_t " << name << "_len;\n";
+ hdr << "extern const unsigned char " << name << "_data[];\n\n";
- fseek(spv, 0, SEEK_END);
- size_t size = ftell(spv);
- fseek(spv, 0, SEEK_SET);
+ if (input_filepath != "") {
+ std::string data = read_binary_file(path);
+ if (data.empty()) {
+ continue;
+ }
- std::vector<unsigned char> data(size);
- size_t read_size = fread(data.data(), 1, size, spv);
- fclose(spv);
- if (read_size != size) {
- std::cerr << "Error reading SPIR-V file: " << path << " (" << strerror(errno) << ")\n";
- continue;
- }
-
- fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
- fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);
-
- fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
- for (size_t i = 0; i < size; ++i) {
- fprintf(src, "0x%02x,", data[i]);
- if ((i + 1) % 12 == 0) fprintf(src, "\n");
- }
- fprintf(src, "\n};\n\n");
-
- if (!no_clean) {
- std::remove(path.c_str());
+ src << "const uint64_t " << name << "_len = " << data.size() << ";\n";
+ src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex;
+ auto bytes = reinterpret_cast<const unsigned char *>(data.data());
+ for (size_t i = 0; i < data.size(); ++i) {
+ src << "0x" << static_cast<uint32_t>(bytes[i]) << ",";
+ if ((i + 1) % 12 == 0) src << "\n";
+ }
+ src << std::dec << "\n};\n\n";
}
}
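
Note: the emitted header declares each shader as an extern length/byte-array pair and the .cpp defines it. The excerpt below shows the rough shape of the generated code; the shader name and length are made up, only the leading 0x03,0x02,0x23,0x07 bytes (the little-endian SPIR-V magic word 0x07230203) are a real constant:

    #include <cstdint>

    // ggml-vulkan-shaders.hpp (illustrative excerpt)
    extern const uint64_t add_f32_len;
    extern const unsigned char add_f32_data[];

    // ggml-vulkan-shaders.cpp (illustrative excerpt)
    const uint64_t add_f32_len = 1480;
    const unsigned char add_f32_data[1480] = {
        0x03,0x02,0x23,0x07, /* ... remaining module bytes ... */
    };
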
std::string suffixes[2] = {"_f32", "_f16"};
- for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
- fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
- fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
- std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
- std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
+ for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
+ hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
+ hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
+
+ std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
+ if (basename(input_filepath) != op_file) {
+ continue;
+ }
+ std::stringstream data = make_generic_stringstream();
+ std::stringstream len = make_generic_stringstream();
+ data << "const void * " << op << "_data[2][2][2][2] = ";
+ len << "const uint64_t " << op << "_len[2][2][2][2] = ";
for (uint32_t t0 = 0; t0 < 2; ++t0) {
if (t0 == 0) {
- data += "{";
- len += "{";
+ data << "{";
+ len << "{";
}
for (uint32_t t1 = 0; t1 < 2; ++t1) {
if (t1 == 0) {
- data += "{";
- len += "{";
+ data << "{";
+ len << "{";
}
for (uint32_t t2 = 0; t2 < 2; ++t2) {
if (t2 == 0) {
- data += "{";
- len += "{";
+ data << "{";
+ len << "{";
}
for (uint32_t rte = 0; rte < 2; ++rte) {
if (rte == 0) {
- data += "{";
- len += "{";
+ data << "{";
+ len << "{";
}
- data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
- len += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
- data += "_data,";
- len += "_len,";
+ data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
+ len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
+ data << "_data,";
+ len << "_len,";
if (rte == 1) {
- data += "}, ";
- len += "}, ";
+ data << "}, ";
+ len << "}, ";
}
}
if (t2 == 1) {
- data += "}, ";
- len += "}, ";
+ data << "}, ";
+ len << "}, ";
}
}
if (t1 == 1) {
- data += "}, ";
- len += "}, ";
+ data << "}, ";
+ len << "}, ";
}
}
if (t0 == 1) {
- data += "};\n";
- len += "};\n";
+ data << "};\n";
+ len << "};\n";
}
}
- fputs(data.c_str(), src);
- fputs(len.c_str(), src);
+ src << data.str();
+ src << len.str();
}
std::vector<std::string> btypes = {"f16", "f32"};
@@ -951,20 +1033,25 @@ void write_output_files() {
if (btype == "q8_1" && !is_legacy_quant(tname)) {
continue;
}
- fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
- fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());
- std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n";
- std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n";
- fputs(data.c_str(), src);
- fputs(len.c_str(), src);
+ hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
+ hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n";
+ if (basename(input_filepath) == "mul_mat_vec.comp") {
+ src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
+ src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
+ }
}
}
- fclose(hdr);
- fclose(src);
-}
+ if (input_filepath == "") {
+ write_file_if_changed(target_hpp, hdr.str());
+ }
+ if (target_cpp != "") {
+ write_binary_file(target_cpp, src.str());
+ }
}
+} // namespace
+
int main(int argc, char** argv) {
std::map<std::string, std::string> args;
for (int i = 1; i < argc; ++i) {
@@ -982,8 +1069,8 @@ int main(int argc, char** argv) {
if (args.find("--glslc") != args.end()) {
GLSLC = args["--glslc"]; // Path to glslc
}
- if (args.find("--input-dir") != args.end()) {
- input_dir = args["--input-dir"]; // Directory containing shader sources
+ if (args.find("--source") != args.end()) {
+ input_filepath = args["--source"]; // The shader source file to compile
}
if (args.find("--output-dir") != args.end()) {
output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
@@ -994,14 +1081,6 @@ int main(int argc, char** argv) {
if (args.find("--target-cpp") != args.end()) {
target_cpp = args["--target-cpp"]; // Path to generated cpp file
}
- if (args.find("--no-clean") != args.end()) {
- no_clean = true; // Keep temporary SPIR-V files in output-dir after build
- }
-
- if (!directory_exists(input_dir)) {
- std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
- return EXIT_FAILURE;
- }
if (!directory_exists(output_dir)) {
if (!create_directory(output_dir)) {
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 93200a4d29..e795ca3fd9 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -28,6 +28,7 @@
/* Constants */
#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16
+#define WEBGPU_WAIT_ANY_BATCH_SIZE 64
#define WEBGPU_MUL_MAT_WG_SIZE 64
#define WEBGPU_NUM_PARAM_BUFS 100
#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters
@@ -35,6 +36,9 @@
#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4
+// For operations which process a row in parallel, this seems like a reasonable default
+#define WEBGPU_ROW_SPLIT_WG_SIZE 64
+
/* End Constants */
// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
@@ -130,15 +134,16 @@ struct webgpu_context_struct {
wgpu::ComputePipeline set_rows_pipeline;
wgpu::ComputePipeline get_rows_pipeline[30];
wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
- wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
- wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
- wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
- wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
- wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
- wgpu::ComputePipeline scale_pipeline[2]; // inplace
+ wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type
+ wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace
+ wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace
+ wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace
+ wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split
+ wgpu::ComputePipeline scale_pipeline[2]; // inplace
+ wgpu::ComputePipeline soft_max_pipeline[3][2][2]; // (no_mask, f32_mask, f16_mask), has_sink, inplace
size_t memset_bytes_per_thread;
@@ -256,8 +261,12 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
}),
UINT64_MAX);
} else {
- // existing callbacks, wait on them
- ctx->instance.WaitAny(ctx->callback_futures.size(), ctx->callback_futures.data(), UINT64_MAX);
+ // WebGPU implementations may limit the number of futures that can be waited on at once,
+ // so wait in batches (64 is what Dawn supports).
+ for (size_t i = 0; i < ctx->callback_futures.size(); i += WEBGPU_WAIT_ANY_BATCH_SIZE) {
+ size_t end = std::min(i + WEBGPU_WAIT_ANY_BATCH_SIZE, ctx->callback_futures.size());
+ ctx->instance.WaitAny(end - i, ctx->callback_futures.data() + i, UINT64_MAX);
+ }
ctx->callback_futures.clear();
}
}
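
Note: Dawn caps how many futures a single WaitAny call may take, so the wait above is issued in fixed-size batches of WEBGPU_WAIT_ANY_BATCH_SIZE (64). Reduced to plain C++, the pattern is chunked iteration; this sketch is illustrative and does not use the WebGPU API:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Hand the callee at most `batch` items per call, walking the vector in chunks.
    template <typename T, typename Fn>
    void for_each_batch(const std::vector<T> & items, size_t batch, Fn && consume) {
        for (size_t i = 0; i < items.size(); i += batch) {
            const size_t end = std::min(i + batch, items.size());
            consume(items.data() + i, end - i);
        }
    }
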
@@ -415,6 +424,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
ctx->staged_param_bufs.push_back(params_bufs);
if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
ggml_backend_webgpu_submit_queue(ctx);
+ ggml_backend_webgpu_wait_on_submission(ctx);
}
}
}
@@ -726,9 +736,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
- size_t max_wg_size = ctx->max_wg_size_x;
- uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size;
- ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x,
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, ggml_nrows(src),
ggml_op_name(dst->op));
}
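
Note: rms_norm is now dispatched with one workgroup per row (ggml_nrows(src)) instead of one thread per row; in the updated WGSL further below, each lane accumulates a strided partial sum of squares and the partials are combined by a shared-memory tree reduction. A CPU sketch of that scheme, assuming wg_size is a power of two (illustrative only, not the shader source):

    #include <cstdint>
    #include <vector>

    static float row_sum_of_squares(const std::vector<float> & row, uint32_t wg_size) {
        std::vector<float> scratch(wg_size, 0.0f);
        // each "lane" sums a strided subset of the row
        for (uint32_t lane = 0; lane < wg_size; ++lane) {
            for (size_t col = lane; col < row.size(); col += wg_size) {
                scratch[lane] += row[col] * row[col];
            }
        }
        // tree reduction: halve the active range until one partial remains
        for (uint32_t offset = wg_size / 2; offset > 0; offset /= 2) {
            for (uint32_t lane = 0; lane < offset; ++lane) {
                scratch[lane] += scratch[lane + offset];
            }
        }
        return scratch[0];
    }
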
@@ -912,6 +920,79 @@ static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tens
ggml_op_name(dst->op));
}
+static void ggml_webgpu_soft_max(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * src2,
+ ggml_tensor * dst) {
+ const int inplace = ggml_webgpu_tensor_equal(src0, dst);
+ const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
+ const int has_sink = (src2 != nullptr);
+ float max_bias;
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
+ float m0 = powf(2.0f, -(max_bias) / n_head_log2);
+ float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
+ has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+ mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
+ mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+ (uint32_t) ggml_nelements(dst),
+ (uint32_t) src0->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) src0->ne[2],
+ mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
+ mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
+ *(uint32_t *) dst->op_params, // scale
+ *(uint32_t *) &max_bias,
+ *(uint32_t *) &n_head_log2,
+ *(uint32_t *) &m0,
+ *(uint32_t *) &m1
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
+ };
+ uint32_t binding_num = 1;
+ if (mask_type < 2) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+ binding_num++;
+ }
+ if (has_sink) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(src2),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
+ binding_num++;
+ }
+ if (!inplace) {
+ entries.push_back({ .binding = binding_num,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+ }
+
+ ggml_backend_webgpu_build_and_enqueue(ctx, ctx->soft_max_pipeline[mask_type][has_sink][inplace], params, entries,
+ ggml_nrows(dst), ggml_op_name(dst->op));
+}
+
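
Note: max_bias, n_head_log2, m0 and m1 feed the usual ggml ALiBi slope scheme; per attention head the shader derives a slope along the lines of the sketch below (a hedged C++ restatement of the convention, not the WGSL source):

    #include <cmath>
    #include <cstdint>

    static float alibi_slope(uint32_t head, float max_bias, uint32_t n_head) {
        if (max_bias <= 0.0f) {
            return 1.0f;  // no positional bias
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -max_bias / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return head < n_head_log2 ? powf(m0, head + 1)
                                  : powf(m1, 2 * (head - n_head_log2) + 1);
    }
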
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@@ -980,6 +1061,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
case GGML_OP_SCALE:
ggml_webgpu_scale(ctx, src0, node);
break;
+ case GGML_OP_SOFT_MAX:
+ ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+ break;
default:
return false;
}
@@ -1237,11 +1321,11 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) {
return reinterpret_cast<ggml_guid_t>((void *) guid_str);
}
-// The max workgroup size is a common constant
-static std::vector<wgpu::ConstantEntry> ggml_webgpu_max_wg_size_entry(webgpu_context & webgpu_ctx) {
+// Workgroup size is a common constant
+static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
- constants[0].value = webgpu_ctx->max_wg_size_x;
+ constants[0].value = wg_size;
return constants;
}
@@ -1309,11 +1393,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline, wgsl_set_rows, "set_rows",
- ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+ ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
}
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_pipeline[GGML_TYPE_F32], wgsl_get_rows_f32_vec,
"get_rows_f32_vec", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->get_rows_f32_no_vec_pipeline, wgsl_get_rows_f32,
@@ -1363,7 +1447,7 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
@@ -1375,7 +1459,7 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16",
@@ -1387,7 +1471,7 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16",
@@ -1399,7 +1483,7 @@ static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16",
@@ -1411,7 +1495,7 @@ static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16",
@@ -1423,7 +1507,7 @@ static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace,
@@ -1431,7 +1515,7 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32,
"rope_f32", constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1],
@@ -1451,7 +1535,7 @@ static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
// reglu
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0],
wgsl_reglu_f32, "reglu_f32", constants);
@@ -1505,13 +1589,43 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
}
static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
- std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32",
constants);
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace,
"scale_f32_inplace", constants);
}
+static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
+ std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][0], wgsl_soft_max_f32,
+ "soft_max_f32", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][0][1], wgsl_soft_max_f32_inplace,
+ "soft_max_f32_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][0], wgsl_soft_max_f32_sink,
+ "soft_max_f32_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[2][1][1],
+ wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][0], wgsl_soft_max_f32_mask_f32,
+ "soft_max_f32_mask_f32", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][0][1],
+ wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][0], wgsl_soft_max_f32_mask_f16,
+ "soft_max_f32_mask_f16", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][0][1],
+ wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][0],
+ wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[0][1][1],
+ wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace",
+ constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][0],
+ wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
+ ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->soft_max_pipeline[1][1][1],
+ wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace",
+ constants);
+}
+
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@@ -1593,6 +1707,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
ggml_tensor * src0 = op->src[0];
ggml_tensor * src1 = op->src[1];
+ ggml_tensor * src2 = op->src[2];
// on smaller devices (or CI), tensors may be larger than the max storage buffer size
if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
@@ -1623,7 +1738,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
break;
case GGML_OP_SET_ROWS:
- supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
+ supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
break;
case GGML_OP_GET_ROWS:
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
@@ -1695,16 +1810,31 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SCALE:
supports_op = op->type == GGML_TYPE_F32;
break;
+ case GGML_OP_SOFT_MAX:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
default:
break;
}
-#ifdef GGML_WEBGPU_DEBUG
- if (!supports_op) {
- WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
- << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
- << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+ if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
+ (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+ (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
+ (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
+ supports_op = false;
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
+ }
+
+ if (!supports_op) {
+ WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+ } else {
+ WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
+ << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+ << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+ << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
}
-#endif
return supports_op;
}
@@ -1826,6 +1956,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
ggml_webgpu_init_rope_pipeline(ctx);
ggml_webgpu_init_glu_pipeline(ctx);
ggml_webgpu_init_scale_pipeline(ctx);
+ ggml_webgpu_init_soft_max_pipeline(ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
index a275eeb978..712b921f1a 100644
--- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
@@ -71,28 +71,53 @@ var<storage, read_write> src: array<f32>;
DECLS
override wg_size: u32;
+var scratch: array;
+
@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
- if (gid.x >= params.ne1 * params.ne2 * params.ne3) {
- return;
- }
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
// one thread per row
- var i = gid.x;
+ var i = wid.x;
let i3 = i / (params.ne2 * params.ne1);
i = i % (params.ne2 * params.ne1);
let i2 = i / params.ne1;
let i1 = i % params.ne1;
let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
- let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+ let elems = (params.ne0 + wg_size - 1) / wg_size;
var sum = 0.0f;
- for (var j: u32 = 0; j < params.ne0; j++) {
- sum += src[i_src_row + j] * src[i_src_row + j];
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ sum += pow(src[i_src_row + col], 2.0);
+ col += wg_size;
}
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ var offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ sum = scratch[0];
+
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
- for (var j: u32 = 0; j < params.ne0; j++) {
- update(i_src_row + j, i_dst_row + j, scale);
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_src_row + col, i_dst_row + col, scale);
+ col += wg_size;
}
}
#end(SHADER)
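The rms_norm shader now assigns one workgroup per row: each invocation accumulates a strided partial sum of squares, the partials are combined with a tree reduction in workgroup memory, and every invocation then applies the resulting scale to its strided slice of the row. A scalar C++ reference of the same per-row computation, useful for checking the shader (a sketch only; names and layout are illustrative):

    #include <cmath>
    #include <cstddef>

    // reference for one row: scale = 1 / sqrt(mean(x^2) + eps)
    static void rms_norm_row_ref(const float * src, float * dst, size_t ne0, float eps) {
        float sum = 0.0f;
        for (size_t j = 0; j < ne0; ++j) {
            sum += src[j] * src[j];          // the shader accumulates this per-thread, then reduces via scratch[]
        }
        const float scale = 1.0f / std::sqrt(sum / (float) ne0 + eps);
        for (size_t j = 0; j < ne0; ++j) {
            dst[j] = src[j] * scale;         // update(i_src_row + col, i_dst_row + col, scale) in the shader
        }
    }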
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
new file mode 100644
index 0000000000..c74dc4cc92
--- /dev/null
+++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl
@@ -0,0 +1,345 @@
+#define(VARIANTS)
+[
+ {
+ "SHADER_NAME": "soft_max_f32",
+ "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_inplace",
+ "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink",
+ "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_sink_inplace",
+ "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f32",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
+ },
+ {
+ "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
+ "REPLS": {
+ "MASK_TYPE" : "f16",
+ },
+ "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
+ }
+]
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(BASE_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS)
+
+#decl(BASE_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<uniform> params: Params;
+#enddecl(BASE_BINDINGS_INPLACE)
+
+#decl(SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS)
+
+#decl(SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(SINK_BINDINGS_INPLACE)
+
+#decl(MASK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS)
+
+#decl(MASK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+#enddecl(MASK_BINDINGS_INPLACE)
+
+#decl(MASK_SINK_BINDINGS)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<f32>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS)
+
+#decl(MASK_SINK_BINDINGS_INPLACE)
+@group(0) @binding(1)
+var<storage, read_write> mask: array<{{MASK_TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> sinks: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#enddecl(MASK_SINK_BINDINGS_INPLACE)
+
+#decl(NOT_INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return dst[i];
+}
+
+fn update(i: u32, val: f32) {
+ dst[i] = val;
+}
+#enddecl(NOT_INPLACE)
+
+#decl(INPLACE)
+fn inter_value(i: u32) -> f32 {
+ return src[i];
+}
+
+fn update(i: u32, val: f32) {
+ src[i] = val;
+}
+#enddecl(INPLACE)
+
+#decl(NO_MASK)
+fn mask_val(i: u32) -> f32 {
+ return 0.0;
+}
+#enddecl(NO_MASK)
+
+#decl(MASK)
+fn mask_val(i: u32) -> f32 {
+ return f32(mask[i]);
+}
+#enddecl(MASK)
+
+#decl(NO_SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return -1e30;
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val;
+}
+#enddecl(NO_SINK)
+
+#decl(SINK)
+fn lower_max_bound(i2: u32) -> f32 {
+ return sinks[params.offset_sinks + i2];
+}
+
+fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
+ return val + exp(sinks[params.offset_sinks + i2] - max_val);
+}
+#enddecl(SINK)
+
+#end(DECLS)
+
+#define(SHADER)
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ // shape of src0/dst
+ ne: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ // shape of src1
+ ne12: u32,
+ ne13: u32,
+
+ scale: f32,
+ max_bias: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<f32>;
+
+DECLS
+
+const CACHE_SIZE: u32 = 16;
+
+override wg_size: u32;
+var scratch: array;
+
+@compute @workgroup_size(wg_size)
+fn main(@builtin(workgroup_id) wid: vec3<u32>,
+ @builtin(local_invocation_id) lid: vec3<u32>) {
+
+ var i = wid.x;
+ let i3 = i / (params.ne2 * params.ne1);
+ i = i % (params.ne2 * params.ne1);
+ let i2 = i / params.ne1;
+ let i1 = i % params.ne1;
+ let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+ let elems = (params.ne0 + wg_size - 1) / wg_size;
+
+ let head = f32(i2);
+ let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
+
+ var cache: array<f32, CACHE_SIZE>;
+
+ var max_val = lower_max_bound(i2);
+ var col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
+ max_val = max(max_val, val);
+ if (col < CACHE_SIZE) {
+ cache[col] = val;
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = max_val;
+ workgroupBarrier();
+ var offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_max = scratch[0];
+ workgroupBarrier();
+
+ var sum = 0.0f;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col),
+ cache[col], col < CACHE_SIZE);
+ let ex = exp(val - row_max);
+ sum += ex;
+ if (col < CACHE_SIZE) {
+ cache[col] = ex;
+ } else {
+ update(i_dst_row + col, ex);
+ }
+ col += wg_size;
+ }
+
+ scratch[lid.x] = sum;
+ workgroupBarrier();
+ offset = wg_size / 2;
+ while (offset > 0) {
+ if (lid.x < offset) {
+ scratch[lid.x] += scratch[lid.x + offset];
+ }
+ offset = offset / 2;
+ workgroupBarrier();
+ }
+ let row_sum = add_sinks(scratch[0], i2, row_max);
+
+ let sum_recip = 1.0 / row_sum;
+ col = lid.x;
+ for (var j: u32 = 0; j < elems; j++) {
+ if (col >= params.ne0) {
+ break;
+ }
+ update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
+ col += wg_size;
+ }
+}
+#end(SHADER)
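The soft_max shader computes a numerically stable row softmax in three strided passes per workgroup: a max reduction over scale * src + slope * mask (with the sink value, when present, as the lower bound via lower_max_bound()), an exp-sum reduction where add_sinks() folds exp(sink - max) into the denominator, and a final normalization; the first CACHE_SIZE columns of the row are kept in per-thread registers to avoid re-reading src. A scalar reference of the same row computation (a sketch under those assumptions; function and parameter names are illustrative):

    #include <algorithm>
    #include <cmath>

    static void soft_max_row_ref(const float * x, const float * mask, const float * sink,
                                 float * dst, int ne0, float scale, float slope) {
        float max_val = sink ? *sink : -1e30f;                    // lower_max_bound()
        for (int j = 0; j < ne0; ++j) {
            const float v = x[j] * scale + (mask ? slope * mask[j] : 0.0f);
            max_val = std::max(max_val, v);
        }
        float sum = 0.0f;
        for (int j = 0; j < ne0; ++j) {
            const float v = x[j] * scale + (mask ? slope * mask[j] : 0.0f);
            dst[j] = std::exp(v - max_val);
            sum += dst[j];
        }
        if (sink) {
            sum += std::exp(*sink - max_val);                     // add_sinks(): the sink only affects the denominator
        }
        for (int j = 0; j < ne0; ++j) {
            dst[j] /= sum;
        }
    }

The slope follows the usual ALiBi scheme visible in the shader: 1.0 when max_bias == 0, otherwise m0^(head+1) for heads below n_head_log2 and m1^(2*(head - n_head_log2) + 1) above it.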
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index aecbdad5a3..2bce1375ba 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID",
"EXP",
"GELU_ERF",
+ "XIELU",
};
-static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
-
+static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
@@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace(
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
}
+// ggml_xielu
+
+struct ggml_tensor * ggml_xielu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha_n,
+ float alpha_p,
+ float beta,
+ float eps) {
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
+ ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
+ ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
+ ggml_set_op_params_f32(result, 3, beta);
+ ggml_set_op_params_f32(result, 4, eps);
+
+ result->op = GGML_OP_UNARY;
+ result->src[0] = a;
+
+ return result;
+}
+
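Note how ggml_xielu() packs its parameters into op_params: slot 0 holds the unary op id, slot 1 the pre-folded beta + softplus(alpha_n), slot 2 softplus(alpha_p), and slots 3-4 beta and eps. A hedged sketch of how a backend kernel could read them back (the helper name is illustrative; ggml_get_op_params_f32 mirrors the ggml_set_op_params_f32 calls above):

    // assumed helper for a backend implementing GGML_UNARY_OP_XIELU
    static void xielu_get_params(const struct ggml_tensor * dst,
                                 float * alpha_n, float * alpha_p, float * beta, float * eps) {
        *alpha_n = ggml_get_op_params_f32(dst, 1);   // beta + softplus(alpha_n)
        *alpha_p = ggml_get_op_params_f32(dst, 2);   // softplus(alpha_p)
        *beta    = ggml_get_op_params_f32(dst, 3);
        *eps     = ggml_get_op_params_f32(dst, 4);
    }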
// ggml_silu_back
struct ggml_tensor * ggml_silu_back(
@@ -3829,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
+struct ggml_tensor * ggml_soft_max_ext_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * mask,
+ float scale,
+ float max_bias) {
+ return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true);
+}
+
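The new in-place variant routes through the same ggml_soft_max_impl() with inplace = true, so the result is written back into the input tensor instead of a duplicated one. A minimal usage sketch (ctx, logits, mask, n_embd_head and max_bias are assumed to exist):

    // same arguments as ggml_soft_max_ext(), but the output aliases `logits`
    ggml_tensor * probs = ggml_soft_max_ext_inplace(ctx, logits, mask, 1.0f / sqrtf((float) n_embd_head), max_bias);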
void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 19bef55bae..492142e57e 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -262,6 +262,7 @@ class Keys:
class ClipVision:
IMAGE_SIZE = "clip.vision.image_size"
+ PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
PATCH_SIZE = "clip.vision.patch_size"
EMBEDDING_LENGTH = "clip.vision.embedding_length"
FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
@@ -298,6 +299,13 @@ class Keys:
class Diffusion:
SHIFT_LOGITS = "diffusion.shift_logits"
+ class xIELU:
+ ALPHA_P = "xielu.alpha_p"
+ ALPHA_N = "xielu.alpha_n"
+ BETA = "xielu.beta"
+ EPS = "xielu.eps"
+
+
#
# recommended mapping of model tensor names for storage in gguf
#
@@ -407,6 +415,7 @@ class MODEL_ARCH(IntEnum):
LLADA_MOE = auto()
SEED_OSS = auto()
GROVEMOE = auto()
+ APERTUS = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -749,6 +758,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLADA_MOE: "llada-moe",
MODEL_ARCH.SEED_OSS: "seed_oss",
MODEL_ARCH.GROVEMOE: "grovemoe",
+ MODEL_ARCH.APERTUS: "apertus",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2723,6 +2733,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
+ MODEL_ARCH.APERTUS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
MODEL_ARCH.LLADA_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 02188f8370..7a30e47fb9 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -1040,6 +1040,9 @@ class GGUFWriter:
def add_vision_image_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
+ def add_vision_preproc_image_size(self, value: int) -> None:
+ self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
+
def add_vision_image_mean(self, values: Sequence[float]) -> None:
self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
@@ -1087,6 +1090,18 @@ class GGUFWriter:
def add_audio_stack_factor(self, value: int) -> None:
self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
+ def add_xielu_alpha_p(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.ALPHA_P, values)
+
+ def add_xielu_alpha_n(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.ALPHA_N, values)
+
+ def add_xielu_beta(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.BETA, values)
+
+ def add_xielu_eps(self, values: Sequence[float]):
+ self.add_array(Keys.xIELU.EPS, values)
+
# diffusion models
def add_diffusion_shift_logits(self, value: bool) -> None:
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 888eda67e6..29fa937852 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -152,6 +152,7 @@ class TensorNameMap:
"model.layers.{bid}.operator_norm", # lfm2
"model.transformer.blocks.{bid}.attn_norm", # llada
"layers.{bid}.input_layernorm", # qwen3-embedding,
+ "model.layers.{bid}.attention_layernorm" # apertus
),
# Attention norm 2
@@ -331,6 +332,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.pre_mlp_norm", # plamo2
"model.transformer.blocks.{bid}.ff_norm", # llada
"layers.{bid}.post_attention_layernorm", # qwen3-embedding
+ "model.layers.{bid}.feedforward_layernorm", # apertus
"layers.{bid}.mlp_norm" # modern bert
),
@@ -556,6 +558,7 @@ class TensorNameMap:
"transformer.layers.{bid}.attn.q_norm", # openelm
"model.layers.layers.{bid}.mixer.q", # plamo2
"layers.{bid}.self_attn.q_norm", # qwen3-embedding
+ "model.layers.{bid}.attention.query_layernorm", # apertus
),
MODEL_TENSOR.ATTN_K_NORM: (
@@ -569,6 +572,7 @@ class TensorNameMap:
"transformer.layers.{bid}.attn.k_norm", # openelm
"model.layers.layers.{bid}.mixer.k", # plamo2
"layers.{bid}.self_attn.k_norm", # qwen3-embedding
+ "model.layers.{bid}.attention.key_layernorm", # apertus
),
MODEL_TENSOR.ROPE_FREQS: (
diff --git a/include/llama.h b/include/llama.h
index 452d9ec5bf..8fc3d7db5a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -543,6 +543,9 @@ extern "C" {
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
+ // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+ LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
@@ -791,8 +794,12 @@ extern "C" {
size_t n_token_capacity,
size_t * n_token_count_out);
+// for backwards-compat
#define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1
+// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+
typedef uint32_t llama_state_seq_flags;
LLAMA_API size_t llama_state_seq_get_size_ext(
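With this change, passing LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY serializes only the partial portion of a sequence state (the SWA part of an iSWA cache, or the recurrent state of a hybrid cache), while the old SWA_ONLY name is kept as a backwards-compatible alias with the same value. A hedged usage sketch with the existing _ext state API (ctx and seq_id are assumed to exist):

    const llama_state_seq_flags flags = LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;

    const size_t n_bytes = llama_state_seq_get_size_ext(ctx, seq_id, flags);
    std::vector<uint8_t> buf(n_bytes);
    llama_state_seq_get_data_ext(ctx, buf.data(), buf.size(), seq_id, flags);
    // ... later, restore only the partial state:
    llama_state_seq_set_data_ext(ctx, buf.data(), buf.size(), seq_id, flags);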
diff --git a/models/templates/Apertus-8B-Instruct.jinja b/models/templates/Apertus-8B-Instruct.jinja
new file mode 100644
index 0000000000..10826ff690
--- /dev/null
+++ b/models/templates/Apertus-8B-Instruct.jinja
@@ -0,0 +1,327 @@
+{%- macro render_typescript_type(param_spec, required_params, is_nullable=false) -%}
+ {%- if param_spec.type == "array" -%}
+ {%- if param_spec['items'] -%}
+ {%- if param_spec['items']['type'] == "string" -%}
+ {{- "string[]" }}
+ {%- elif param_spec['items']['type'] == "number" -%}
+ {{- "number[]" }}
+ {%- elif param_spec['items']['type'] == "integer" -%}
+ {{- "number[]" }}
+ {%- elif param_spec['items']['type'] == "boolean" -%}
+ {{- "boolean[]" }}
+ {%- else -%}
+ {%- set inner_type = render_typescript_type(param_spec['items'], required_params) -%}
+ {%- if inner_type == "object | object" or inner_type|length > 50 -%}
+ {{- "any[]" }}
+ {%- else -%}
+ {{- inner_type + "[]" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- if param_spec.nullable -%}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- else -%}
+ {{- "any[]" }}
+ {%- if param_spec.nullable -%}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- elif param_spec.type is defined and param_spec.type is iterable and param_spec.type is not string and param_spec.type is not mapping and param_spec.type[0] is defined -%}
+ {#- Handle array of types like ["object", "object"] from Union[dict, list] #}
+ {%- if param_spec.type | length > 1 -%}
+ {{- param_spec.type | join(" | ") }}
+ {%- else -%}
+ {{- param_spec.type[0] }}
+ {%- endif -%}
+ {%- elif param_spec.oneOf -%}
+ {#- Handle oneOf schemas - check for complex unions and fallback to any #}
+ {%- set has_object_variants = false -%}
+ {%- for variant in param_spec.oneOf -%}
+ {%- if variant.type == "object" -%}
+ {%- set has_object_variants = true -%}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- if has_object_variants and param_spec.oneOf|length > 1 -%}
+ {{- "any" }}
+ {%- else -%}
+ {%- for variant in param_spec.oneOf -%}
+ {{- render_typescript_type(variant, required_params) -}}
+ {%- if variant.description %}
+ {{- "// " + variant.description }}
+ {%- endif -%}
+ {%- if variant.default is defined %}
+ {{ "// default: " + variant.default|tojson }}
+ {%- endif -%}
+ {%- if not loop.last %}
+ {{- " | " }}
+ {% endif -%}
+ {%- endfor -%}
+ {%- endif -%}
+ {%- elif param_spec.type == "string" -%}
+ {%- if param_spec.enum -%}
+ {{- '"' + param_spec.enum|join('" | "') + '"' -}}
+ {%- else -%}
+ {{- "string" }}
+ {%- if param_spec.nullable %}
+ {{- " | null" }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- elif param_spec.type == "number" -%}
+ {{- "number" }}
+ {%- elif param_spec.type == "integer" -%}
+ {{- "number" }}
+ {%- elif param_spec.type == "boolean" -%}
+ {{- "boolean" }}
+ {%- elif param_spec.type == "object" -%}
+ {%- if param_spec.properties -%}
+ {{- "{\n" }}
+ {%- for prop_name, prop_spec in param_spec.properties.items() -%}
+ {{- prop_name -}}
+ {%- if prop_name not in (param_spec.required or []) -%}
+ {{- "?" }}
+ {%- endif -%}
+ {{- ": " }}
+ {{ render_typescript_type(prop_spec, param_spec.required or []) }}
+ {%- if not loop.last -%}
+ {{-", " }}
+ {%- endif -%}
+ {%- endfor -%}
+ {{- "}" }}
+ {%- else -%}
+ {{- "object" }}
+ {%- endif -%}
+ {%- else -%}
+ {{- "any" }}
+ {%- endif -%}
+{%- endmacro -%}
+
+{%- macro render_tools(tools) -%}
+ {%- for tool in tools %}
+ {{- "// " + tool.description + "\n" }}
+ {{- "type "+ tool.name + " = " }}
+ {%- if tool.parameters and tool.parameters.properties %}
+ {{- "(_: {\n" }}
+ {%- for param_name, param_spec in tool.parameters.properties.items() %}
+ {%- if param_spec.description %}
+ {{- "// " + param_spec.description + "\n" }}
+ {%- endif %}
+ {{- param_name }}
+ {%- if param_name not in (tool.parameters.required or []) -%}
+ {{- "?" }}
+ {%- endif -%}
+ {{- ": " }}
+ {{- render_typescript_type(param_spec, tool.parameters.required or []) }}
+ {%- if param_spec.default is defined -%}
+ {%- if param_spec.enum %}
+ {{- ", // default: " + param_spec.default }}
+ {%- elif param_spec.oneOf %}
+ {{- "// default: " + param_spec.default }}
+ {%- else %}
+ {{- ", // default: " + param_spec.default|tojson }}
+ {%- endif -%}
+ {%- endif -%}
+ {%- if not loop.last %}
+ {{- ",\n" }}
+ {%- else %}
+ {{- "\n" }}
+ {%- endif -%}
+ {%- endfor %}
+ {{- "}) => any;" }}
+ {%- else -%}
+ {{- "() => any;" }}
+ {%- endif -%}
+ {%- if not loop.last -%}
+ {{- "\n" }}
+ {%- endif -%}
+ {%- endfor %}
+{%- endmacro -%}
+
+{{ bos_token }}
+
+{%- set system_token = '<|system_start|>' -%}
+{%- set end_system_token = '<|system_end|>' -%}
+{%- set developer_token = '<|developer_start|>' -%}
+{%- set end_developer_token = '<|developer_end|>' -%}
+{%- set user_token = '<|user_start|>' -%}
+{%- set end_user_token = '<|user_end|>' -%}
+{%- set assistant_token = '<|assistant_start|>' -%}
+{%- set end_assistant_token = '<|assistant_end|>' -%}
+{%- set inner_token = '<|inner_prefix|>' -%}
+{%- set outer_token = '<|inner_suffix|>' -%}
+{%- set tool_calls_token = '<|tools_prefix|>' -%}
+{%- set end_tool_calls_token = '<|tools_suffix|>' -%}
+
+{%- set ns = namespace(in_assistant=false, in_tool=false, in_inner=false, assistant_format=none) -%}
+
+{%- if messages and messages[0].role == 'system' -%}
+ {%- if "content" in messages[0] -%}
+ {%- if messages[0].content is string -%}
+ {{ system_token + messages[0].content + end_system_token }}
+ {%- elif messages[0].content is mapping and "text" in messages[0].content -%}
+ {{ system_token + messages[0].content.text + end_system_token }}
+ {%- else -%}
+ {{- raise_exception("Invalid system message") -}}
+ {%- endif -%}
+ {%- else -%}
+ {{- raise_exception("Invalid system message") -}}
+ {%- endif -%}
+ {%- set loop_messages = messages[1:] -%}
+{%- else -%}
+ {{ system_token + 'You are Apertus, a helpful assistant created by the SwissAI initiative.\nKnowledge cutoff: 2024-04\nCurrent date: ' + strftime_now('%Y-%m-%d') + end_system_token }}
+ {%- set loop_messages = messages -%}
+{%- endif -%}
+
+{{ developer_token + 'Deliberation: ' }}
+{%- if enable_thinking is defined and enable_thinking -%}
+ {{ 'enabled\n' }}
+{%- else -%}
+ {{ 'disabled\n' }}
+{%- endif -%}
+{%- if tools is defined and tools -%}
+ {{ 'Tool Capabilities:\n' + render_tools(tools) }}
+{%- else -%}
+ {{ 'Tool Capabilities: disabled' }}
+{%- endif -%}
+{{ end_developer_token }}
+
+{%- for message in loop_messages -%}
+ {%- if message.role == 'user' -%}
+ {%- set ns.in_inner = false -%}
+ {%- if ns.in_tool -%}
+ {{ ']' }}
+ {%- set ns.in_tool = false -%}
+ {%- endif -%}
+ {%- if ns.in_assistant -%}
+ {{ end_assistant_token }}
+ {%- set ns.in_assistant = false -%}
+ {%- endif -%}
+ {%- if "content" in message -%}
+ {{ user_token }}
+ {%- if message.content is string -%}
+ {{ message.content }}
+ {%- elif message.content is mapping and "parts" in message.content -%}
+ {%- set parts = message.content.parts -%}
+ {%- for part in parts -%}
+ {%- if part.type == "text" -%}
+ {{ part.text }}
+ {%- else -%}
+ {{- raise_exception("Invalid user part: " + part.type) -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{- raise_exception("Invalid user message: " + message.role) -}}
+ {%- endif -%}
+ {{ end_user_token }}
+ {%- endif -%}
+ {%- elif message.role == 'assistant' -%}
+ {%- if not ns.in_assistant -%}
+ {{ assistant_token }}
+ {%- set ns.in_assistant = true -%}
+ {%- endif -%}
+ {%- if "content" in message and message.content is not none -%}
+ {%- if message.content is string and (ns.assistant_format is none or ns.assistant_format == "string") -%}
+ {%- if ns.in_tool -%}
+ {{ ']' }}
+ {%- set ns.in_tool = false -%}
+ {%- endif -%}
+ {%- set ns.assistant_format = "string" -%}
+ {{ message.content }}
+ {%- elif message.content is mapping and "blocks" in message.content and (ns.assistant_format is none or ns.assistant_format == "mapping") -%}
+ {%- set ns.assistant_format = "mapping" -%}
+ {%- set blocks = message.content.blocks -%}
+ {%- for block in blocks -%}
+ {%- if block.type == 'thoughts' -%}
+ {%- if ns.in_tool -%}
+ {{ ']' }}
+ {%- set ns.in_tool = false -%}
+ {%- endif -%}
+ {%- if not ns.in_inner -%}
+ {%- set ns.in_inner = true -%}
+ {{ inner_token }}
+ {%- endif -%}
+ {{ block.text }}
+ {%- elif block.type == 'tool_calls' -%}
+ {%- if ns.in_tool -%}
+ {{ ']' }}
+ {%- set ns.in_tool = false -%}
+ {%- endif -%}
+ {%- if ns.in_inner and not loop.first and block.calls|length == 1 and block.calls[0].name == 'display_answers' -%}
+ {%- set ns.in_inner = false -%}
+ {{ outer_token }}
+ {%- endif -%}
+ {{ tool_calls_token + '[' }}
+ {%- for tool_call in block.calls -%}
+ {{- '{"' + tool_call.name + '": ' + tool_call.arguments + '}' }}
+ {%- if not loop.last -%}
+ {{- ", " }}
+ {%- endif -%}
+ {%- endfor -%}
+ {{ ']' + end_tool_calls_token }}
+ {%- elif block.type == 'tool_outputs' -%}
+ {%- if ns.in_tool -%}
+ {{- raise_exception("Cannot have both tool outputs as separate messages and tool outputs as blocks") -}}
+ {%- endif -%}
+ {{ '[' }}
+ {%- for tool_output in block.outputs -%}
+ {{- tool_output.output }}
+ {%- if not loop.last -%}
+ {{- ", " }}
+ {%- endif -%}
+ {%- endfor -%}
+ {{- ']' }}
+ {%- elif block.type == 'response' -%}
+ {%- if ns.in_tool -%}
+ {{ ']' }}
+ {%- set ns.in_tool = false -%}
+ {%- endif -%}
+ {%- if (not loop.first and ns.in_inner) or (ns.in_assistant and ns.in_inner) -%}
+ {%- set ns.in_inner = false -%}
+ {{ outer_token }}
+ {%- endif -%}
+ {{ block.text }}
+ {%- else -%}
+ {{- raise_exception("Invalid assistant block type: " + block.type) -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {%- else -%}
+ {{- raise_exception("Invalid assistant content '" + message.content + "', expected " + ns.assistant_format) -}}
+ {%- endif -%}
+ {%- elif "tool_calls" not in message -%}
+ {{- raise_exception("Invalid assistant message " + message) -}}
+ {%- endif -%}
+ {%- if "tool_calls" in message and message.tool_calls -%}
+ {{ tool_calls_token + '[' }}
+ {%- for tool_call in message.tool_calls -%}
+ {%- if tool_call.type == 'function' -%}
+ {%- set function = tool_call.function -%}
+ {{- '{"' + function.name + '": ' + function.arguments + '}' }}
+ {%- if not loop.last -%}
+ {{- ", " }}
+ {%- endif -%}
+ {%- else -%}
+ {{- raise_exception("Invalid tool call type: " + tool_call.type) -}}
+ {%- endif -%}
+ {%- endfor -%}
+ {{ ']' + end_tool_calls_token }}
+ {%- endif -%}
+ {%- elif message.role == 'tool' -%}
+ {%- if not ns.in_assistant -%}
+ {{- raise_exception("Tool message outside of assistant") -}}
+ {%- endif -%}
+ {%- if not ns.in_tool -%}
+ {{ '[' }}
+ {%- set ns.in_tool = true -%}
+ {%- else -%}
+ {{ ", "}}
+ {%- endif -%}
+ {{ message.content }}
+ {%- else -%}
+ {{- raise_exception("Invalid message role") -}}
+ {%- endif -%}
+{%- endfor -%}
+{%- if ns.in_tool -%}
+ {{ ']' }}
+{%- endif -%}
+{%- if add_generation_prompt -%}
+ {{ assistant_token }}
+{%- endif -%}
\ No newline at end of file
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 48f92c863e..f8f2d09778 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -100,6 +100,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_GROVEMOE, "grovemoe" },
+ { LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -259,6 +260,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
{ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
+ { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" },
+ { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" },
+ { LLM_KV_XIELU_BETA, "xielu.beta" },
+ { LLM_KV_XIELU_EPS, "xielu.eps" },
+
// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
@@ -2139,6 +2145,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
},
},
+ {
+ LLM_ARCH_APERTUS,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_DREAM,
{
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 8ec51a1daa..ae6559afd3 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -104,6 +104,7 @@ enum llm_arch {
LLM_ARCH_LLADA_MOE,
LLM_ARCH_SEED_OSS,
LLM_ARCH_GROVEMOE,
+ LLM_ARCH_APERTUS,
LLM_ARCH_UNKNOWN,
};
@@ -262,6 +263,11 @@ enum llm_kv {
LLM_KV_SHORTCONV_L_CACHE,
+ LLM_KV_XIELU_ALPHA_N,
+ LLM_KV_XIELU_ALPHA_P,
+ LLM_KV_XIELU_BETA,
+ LLM_KV_XIELU_EPS,
+
// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 0fe4b56942..f29b23eeff 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -42,7 +42,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
- int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
+ int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -169,6 +169,12 @@ struct llama_hparams {
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;
+ // xIELU
+ std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
+ std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
+ std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
+ std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
+
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp
index 827302e6d2..facba1d004 100644
--- a/src/llama-kv-cache-iswa.cpp
+++ b/src/llama-kv-cache-iswa.cpp
@@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const {
}
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
- if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}
@@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id
}
void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
- if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index abf652483c..cb8832a353 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -175,17 +175,17 @@ std::map llama_memory_hybrid::memory_breakdo
}
void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
- GGML_UNUSED(flags);
-
- mem_attn->state_write(io, seq_id);
- mem_recr->state_write(io, seq_id);
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+ mem_attn->state_write(io, seq_id, flags);
+ }
+ mem_recr->state_write(io, seq_id, flags);
}
void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
- GGML_UNUSED(flags);
-
- mem_attn->state_read(io, seq_id);
- mem_recr->state_read(io, seq_id);
+ if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+ mem_attn->state_read(io, seq_id, flags);
+ }
+ mem_recr->state_read(io, seq_id, flags);
}
llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 44645fcdd2..e23e74982b 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
}
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+ //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;
if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
- if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+ if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+ //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
// invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} else {
// seq_id is negative, then the range should include everything or nothing
 if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+ //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
return false;
}
}
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 8182a9adf5..aa3a65f87a 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -465,6 +465,8 @@ namespace GGUFMeta {
// TODO: this is not very clever - figure out something better
 template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
 template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
+ template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required);
+
llama_model_loader::llama_model_loader(
const std::string & fname,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index dcb0d5bf6c..a657a7cc71 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -512,9 +512,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
llm_arch_is_recurrent(ml.get_arch()));
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
-
std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0);
+ std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f);
+ std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f);
+ std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f);
+ std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f);
+
ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false);
ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
@@ -1103,7 +1107,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
break;
default: type = LLM_TYPE_UNKNOWN;
- }
+ }
+
+ // Load attention parameters
+ ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false);
+ ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
} break;
case LLM_ARCH_GPT2:
{
@@ -2048,6 +2056,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_APERTUS:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer);
+ ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer);
+ ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer);
+ ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer);
+
+ switch (hparams.n_layer) {
+ case 32: type = LLM_TYPE_8B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
default: throw std::runtime_error("unsupported model architecture");
}
@@ -3442,17 +3463,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_PLAMO2:
{
+ // mamba parameters
const uint32_t d_conv = hparams.ssm_d_conv;
const uint32_t d_state = hparams.ssm_d_state;
const uint32_t num_heads = hparams.ssm_dt_rank;
const uint32_t intermediate_size = hparams.ssm_d_inner;
- const uint32_t head_dim = intermediate_size / num_heads;
- const uint32_t qk_dim = head_dim;
- const uint32_t v_dim = head_dim;
- const int64_t num_attention_heads = hparams.n_head();
- const int64_t q_num_heads = num_attention_heads;
const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16));
+ // attention parameters
+ const uint32_t qk_dim = hparams.n_embd_head_k;
+ const uint32_t v_dim = hparams.n_embd_head_v;
+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
@@ -3486,6 +3507,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0);
layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0);
} else {
+ const int64_t num_attention_heads = hparams.n_head(i);
+ const int64_t q_num_heads = num_attention_heads;
const int64_t num_key_value_heads = hparams.n_head_kv(i);
const int64_t k_num_heads = num_key_value_heads;
const int64_t v_num_heads = num_key_value_heads;
@@ -3494,8 +3517,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const int64_t v_proj_dim = v_num_heads * v_dim;
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0);
- layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0);
- layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0);
}
@@ -5959,6 +5982,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0);
}
} break;
+ case LLM_ARCH_APERTUS:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+ if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+ layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ } else {
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ }
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+ // optional bias tensors
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+
+ // Q and K layernorms for Apertus
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
+ layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
+ layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
@@ -7828,6 +7893,8 @@ struct llm_build_bert : public llm_graph_context {
}
if (model.layers[il].attn_q_norm) {
+ Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
+
Qcur = build_norm(Qcur,
model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b,
@@ -7837,6 +7904,8 @@ struct llm_build_bert : public llm_graph_context {
}
if (model.layers[il].attn_k_norm) {
+ Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
+
Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm,
model.layers[il].attn_k_norm_b,
@@ -8339,6 +8408,9 @@ struct llm_build_mpt : public llm_graph_context {
// Q/K Layernorm
if (model.layers[il].attn_q_norm) {
+ Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens);
+ Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens);
+
Qcur = build_norm(Qcur,
model.layers[il].attn_q_norm,
model.layers[il].attn_q_norm_b,
@@ -17781,6 +17853,7 @@ private:
const int64_t n_embd_head_q = hparams.n_embd_head_k;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_head_v = hparams.n_embd_head_v;
+ int32_t n_head = hparams.n_head(il);
int32_t n_head_kv = hparams.n_head_kv(il);
const int64_t q_offset = 0;
@@ -19262,6 +19335,141 @@ struct llm_build_grovemoe : public llm_graph_context {
}
};
+struct llm_build_apertus : public llm_graph_context {
+ llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ ggml_tensor * inp_pos = build_inp_pos();
+ auto * inp_attn = build_attn_inp_kv();
+
+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur_pos", il);
+ cb(Kcur, "Kcur_pos", il);
+ cb(Vcur, "Vcur_pos", il);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+ cb(cur, "attn_out", il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network with xIELU activation
+ {
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // Up projection
+ ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur);
+ cb(up, "ffn_up", il);
+
+ float alpha_n_val = hparams.xielu_alpha_n[il];
+ float alpha_p_val = hparams.xielu_alpha_p[il];
+ float beta_val = hparams.xielu_beta[il];
+ float eps_val = hparams.xielu_eps[il];
+
+ // Apply xIELU activation
+ ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val);
+ cb(activated, "ffn_xielu", il);
+
+ // Down projection
+ cur = build_lora_mm(model.layers[il].ffn_down, activated);
+ cb(cur, "ffn_down", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, nullptr,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
@@ -19797,6 +20005,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique(*this, params);
} break;
+ case LLM_ARCH_APERTUS:
+ {
+ llm = std::make_unique<llm_build_apertus>(*this, params);
+ } break;
default:
GGML_ABORT("fatal error");
}
@@ -20004,6 +20216,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GLM4_MOE:
case LLM_ARCH_SEED_OSS:
case LLM_ARCH_GROVEMOE:
+ case LLM_ARCH_APERTUS:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
@@ -20114,6 +20327,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
return llm_arch_is_recurrent(model->arch);
}
+bool llama_model_is_hybrid(const llama_model * model) {
+ return llm_arch_is_hybrid(model->arch);
+}
+
bool llama_model_is_diffusion(const llama_model * model) {
return llm_arch_is_diffusion(model->arch);
}
diff --git a/src/llama-model.h b/src/llama-model.h
index 53e3506369..f7819a6b64 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -381,6 +381,12 @@ struct llama_layer {
// openai-moe
struct ggml_tensor * attn_sinks = nullptr;
+ // xIELU activation parameters for Apertus
+ struct ggml_tensor * ffn_act_alpha_n = nullptr;
+ struct ggml_tensor * ffn_act_alpha_p = nullptr;
+ struct ggml_tensor * ffn_act_beta = nullptr;
+ struct ggml_tensor * ffn_act_eps = nullptr;
+
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 5a6dc436e2..0b84b9eec1 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
case LLAMA_VOCAB_PRE_TYPE_OLMO:
case LLAMA_VOCAB_PRE_TYPE_JAIS:
case LLAMA_VOCAB_PRE_TYPE_TRILLION:
+ case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING:
regex_exprs = {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
@@ -1962,6 +1963,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "trillion") {
pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION;
clean_spaces = false;
+ } else if (
+ tokenizer_pre == "granite-docling") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING;
+ clean_spaces = false;
} else if (
tokenizer_pre == "bailingmoe" ||
tokenizer_pre == "llada-moe") {
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 0d2f28c36c..5e468675e4 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -8,46 +8,47 @@
// pre-tokenization types
enum llama_vocab_pre_type {
- LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
- LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
- LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
- LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
- LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
- LLAMA_VOCAB_PRE_TYPE_MPT = 5,
- LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
- LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
- LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
- LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
- LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
- LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
- LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
- LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
- LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
- LLAMA_VOCAB_PRE_TYPE_PORO = 15,
- LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
- LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
- LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
- LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
- LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
- LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
- LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
- LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
- LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
- LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
- LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
- LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
- LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
- LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
- LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
- LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
- LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
- LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
- LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
- LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
- LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
- LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
- LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
- LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
+ LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
+ LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+ LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
+ LLAMA_VOCAB_PRE_TYPE_MPT = 5,
+ LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
+ LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
+ LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
+ LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
+ LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
+ LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
+ LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
+ LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
+ LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
+ LLAMA_VOCAB_PRE_TYPE_PORO = 15,
+ LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
+ LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
+ LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
+ LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
+ LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
+ LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
+ LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
+ LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
+ LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
+ LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
+ LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
+ LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
+ LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
+ LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
+ LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
+ LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+ LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
+ LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
+ LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+ LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
+ LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
+ LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
+ LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
+ LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40,
};
struct LLM_KV;
diff --git a/tests/test-alloc.cpp b/tests/test-alloc.cpp
index 2eb7724731..95e09c97b0 100644
--- a/tests/test-alloc.cpp
+++ b/tests/test-alloc.cpp
@@ -548,6 +548,41 @@ static void test_buffer_size_zero() {
GGML_ASSERT(backend_b.context->allocated_total() == 0);
}
+// Test re-using gallocr for a different graph. The new graph has the same
+// total size, but one of the chunks is larger, so reallocation is required.
+static void test_reallocation() {
+ dummy_backend backend = dummy_backend_init(32, /*align*/ 4);
+ ggml_gallocr_ptr galloc;
+ {
+ auto [ctx, graph, ctx_ptr] = make_context();
+ ggml_tensor * x[4];
+ x[0] = make_input_with_size(ctx, 24);
+ x[1] = make_input_with_size(ctx, 16);
+ x[2] = ggml_view_1d(ctx, x[0], 4, 0);
+ x[3] = ggml_add(ctx, x[2], x[1]);
+ assign_names(ctx);
+
+ galloc = allocate_graph(graph, x[3], &backend.buffer_type);
+ check_all_allocated(graph);
+ GGML_ASSERT(backend.context->allocated_total() == 40);
+ }
+ {
+ auto [ctx, graph, ctx_ptr] = make_context();
+ ggml_tensor * x[3];
+ x[0] = make_input_with_size(ctx, 20);
+ x[1] = make_input_with_size(ctx, 20);
+ x[2] = ggml_add(ctx, x[0], x[1]);
+ assign_names(ctx);
+ ggml_set_output(x[2]);
+ ggml_build_forward_expand(graph, x[2]);
+
+ bool result = ggml_gallocr_alloc_graph(galloc.get(), graph);
+ GGML_ASSERT(result);
+ check_all_allocated(graph);
+ GGML_ASSERT(backend.context->allocated_total() == 40);
+ }
+}
+
static void run(const char * name, void (*f)()) {
printf("%s ", name);
fflush(stdout);
@@ -568,5 +603,6 @@ int main() {
run("test_prefer_already_allocated_memory", test_prefer_already_allocated_memory);
run("test_multiple_buffer_types", test_multiple_buffer_types);
run("test_buffer_size_zero", test_buffer_size_zero);
+ run("test_reallocation", test_reallocation);
return 0;
}
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 62d815cc26..c1e45972e5 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3752,9 +3752,10 @@ struct test_soft_max : public test_case {
 const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
const float scale;
const float max_bias;
+ const bool inplace;
std::string vars() override {
- return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
+ return VARS_TO_STR9(type, ne, mask, sinks, m_prec, nr23, scale, max_bias, inplace);
}
// the 1024 test with bias occasionally fails:
@@ -3770,8 +3771,9 @@ struct test_soft_max : public test_case {
ggml_type m_prec = GGML_TYPE_F32,
 std::array<int64_t, 2> nr23 = {1, 1},
float scale = 1.0f,
- float max_bias = 0.0f)
- : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+ float max_bias = 0.0f,
+ bool inplace = false)
+ : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias), inplace(inplace) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3790,7 +3792,12 @@ struct test_soft_max : public test_case {
ggml_set_name(sinks, "sinks");
}
- ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+ ggml_tensor * out;
+ if (inplace) {
+ out = ggml_soft_max_ext_inplace(ctx, a, mask, scale, max_bias);
+ } else {
+ out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+ }
ggml_soft_max_add_sinks(out, sinks);
ggml_set_name(out, "out");
@@ -6562,6 +6569,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
}
+ // inplace tests
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f, true));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, mask, sinks, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f, true));
}
}
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
diff --git a/tests/test-barrier.cpp b/tests/test-barrier.cpp
index d85bf912b2..04c27761dc 100644
--- a/tests/test-barrier.cpp
+++ b/tests/test-barrier.cpp
@@ -1,6 +1,5 @@
#include "ggml.h"
#include "ggml-cpu.h"
-#include "ggml-backend.h"
#include
#include
@@ -8,12 +7,13 @@
#include
#include
#include
+#include
#define MAX_NARGS 2
int main(int argc, char *argv[]) {
- int n_threads = 4;
+ int n_threads = std::max(1, std::min(4, (int) std::thread::hardware_concurrency()));
int n_rounds = 100;
if (argc > 1) {
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index ce0f4b0a2a..52e23b5ac6 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -411,6 +411,7 @@ const common_chat_msg message_assist_thoughts_unparsed_md = simple_assis
const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("<think>I'm\nthinking</think>Hello, world!\nWhat's up?\n```json\n{}");
const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Réponse");
const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking");
@@ -745,6 +746,17 @@ static void test_template_output_parsers() {
tmpls.get(), end_tokens, message_assist_call_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
}
+ {
+ assert_msg_equals(
+ simple_assist_msg("Réponse", "raisonnement"),
+ common_chat_parse(
+ message_assist_thoughts_unparsed_magistral.content,
+ /* is_partial= */ false,
+ {
+ /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL,
+ /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
+ }));
+ }
{
auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
std::vector end_tokens{ "<|im_end|>" };
@@ -2054,6 +2066,79 @@ static void test_template_output_parsers() {
/* .parse_tool_calls = */ true,
}));
}
+ {
+ auto tmpls = read_templates("models/templates/Apertus-8B-Instruct.jinja");
+ std::vector<std::string> end_tokens{ "<|assistant_end|>" };
+
+ assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+ assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+
+ // Test parsing regular content
+ assert_msg_equals(message_assist,
+ common_chat_parse(
+ "Hello, world!\nWhat's up?",
+ /* is_partial= */ false,
+ {COMMON_CHAT_FORMAT_APERTUS}));
+
+ // Test parsing content with thinking
+ assert_msg_equals(message_assist_thoughts,
+ common_chat_parse(
+ "<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?",
+ /* is_partial= */ false,
+ {
+ /* .format = */ COMMON_CHAT_FORMAT_APERTUS,
+ /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+ }));
+
+ // Test parsing tool calls
+ assert_msg_equals(message_assist_call,
+ common_chat_parse(
+ "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
+ /* is_partial= */ false,
+ {COMMON_CHAT_FORMAT_APERTUS}));
+
+ // Test parsing tool calls with thinking
+ assert_msg_equals(message_assist_call_thoughts,
+ common_chat_parse(
+ "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
+ /* is_partial= */ false,
+ {
+ /* .format = */ COMMON_CHAT_FORMAT_APERTUS,
+ /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
+ }));
+
+ // Test tool calls with extra content
+ assert_msg_equals(message_assist_call_content,
+ common_chat_parse(
+ "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?",
+ /* is_partial= */ false,
+ {COMMON_CHAT_FORMAT_APERTUS}
+ ));
+
+ // Test tool calls with extra content AND thinking
+ assert_msg_equals(message_assist_call_thoughts_content,
+ common_chat_parse(
+ "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?",
+ /* is_partial= */ false,
+ {
+ /* .format = */ COMMON_CHAT_FORMAT_APERTUS,
+ /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
+ }));
+
+ // Test template generation for regular content
+ test_templates(tmpls.get(), end_tokens, message_assist, tools,
+ "Hello, world!\nWhat's up?",
+ /* expect_grammar_triggered= */ false);
+
+ // Test template generation for tool calls
+ test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
+ "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>",
+ /* expect_grammar_triggered= */ true
+ );
+
+ assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get()));
+ }
+
}
static void test_msg_diffs_compute() {
diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
index 275ba367c0..89bc01b485 100644
--- a/tools/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -168,7 +168,7 @@ static std::vector<ggml_backend_dev_t> parse_devices_arg(const std::string & val
return devices;
}
-static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::string & servers) {
+static void register_rpc_server_list(const std::string & servers) {
auto rpc_servers = string_split(servers, ',');
if (rpc_servers.empty()) {
throw std::invalid_argument("no RPC servers specified");
@@ -179,36 +179,15 @@ static std::vector<ggml_backend_dev_t> register_rpc_device_list(const std::strin
throw std::invalid_argument("failed to find RPC backend");
}
- using add_rpc_device_fn = ggml_backend_dev_t (*)(const char * endpoint);
- auto * ggml_backend_rpc_add_device_fn = (add_rpc_device_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
- if (!ggml_backend_rpc_add_device_fn) {
- throw std::invalid_argument("failed to find RPC device add function");
+ using add_rpc_server_fn = ggml_backend_reg_t (*)(const char * endpoint);
+ auto * ggml_backend_rpc_add_server_fn = (add_rpc_server_fn) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
+ if (!ggml_backend_rpc_add_server_fn) {
+ throw std::invalid_argument("failed to find RPC add server function");
}
-
- static std::unordered_set registered;
- std::vector<ggml_backend_dev_t> devices;
for (const auto & server : rpc_servers) {
- ggml_backend_dev_t dev = nullptr;
-
- std::string name = string_format("RPC[%s]", server.c_str());
-
- if (registered.find(server) != registered.end()) {
- dev = ggml_backend_dev_by_name(name.c_str());
- }
-
- if (!dev) {
- dev = ggml_backend_rpc_add_device_fn(server.c_str());
- if (!dev) {
- throw std::invalid_argument(string_format("failed to add RPC device for server '%s'", server.c_str()));
- }
- ggml_backend_device_register(dev);
- registered.insert(server);
- }
-
- devices.push_back(dev);
+ auto reg = ggml_backend_rpc_add_server_fn(server.c_str());
+ ggml_backend_register(reg);
}
-
- return devices;
}
static std::string devices_to_string(const std::vector<ggml_backend_dev_t> & devices) {
@@ -714,7 +693,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
break;
}
try {
- register_rpc_device_list(argv[i]);
+ register_rpc_server_list(argv[i]);
} catch (const std::exception & e) {
fprintf(stderr, "error: %s\n", e.what());
invalid_param = true;
@@ -1368,13 +1347,23 @@ struct test {
static std::string get_backend() {
std::vector<std::string> backends;
+ bool rpc_used = false;
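+ // report the RPC backend only if at least one RPC server device has been registered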
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto * reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
- if (name != "CPU") {
- backends.push_back(ggml_backend_reg_name(reg));
+ if (string_starts_with(name, "RPC")) {
+ if (ggml_backend_reg_dev_count(reg) > 0) {
+ rpc_used = true;
+ }
+ } else {
+ if (name != "CPU") {
+ backends.push_back(ggml_backend_reg_name(reg));
+ }
}
}
+ if (rpc_used) {
+ backends.push_back("RPC");
+ }
return backends.empty() ? "CPU" : join(backends, ",");
}
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 664b0c9ac6..7a7523851c 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -31,6 +31,7 @@
// vision-specific
#define KEY_IMAGE_SIZE "clip.vision.image_size"
+#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
#define KEY_PATCH_SIZE "clip.vision.patch_size"
#define KEY_IMAGE_MEAN "clip.vision.image_mean"
#define KEY_IMAGE_STD "clip.vision.image_std"
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 210ecc883f..98e68af27a 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -170,7 +170,9 @@ struct clip_hparams {
int32_t projection_dim;
int32_t n_head;
int32_t n_layer;
- int32_t proj_scale_factor = 0; // idefics3
+ // idefics3
+ int32_t preproc_image_size = 0;
+ int32_t proj_scale_factor = 0;
float image_mean[3];
float image_std[3];
@@ -2250,6 +2252,7 @@ struct clip_model_loader {
if (is_vision) {
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+ get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
@@ -3551,10 +3554,51 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
// res_imgs->data[0] = *res;
res_imgs->entries.push_back(std::move(img_f32));
return true;
- }
- else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+ // The refined size has two steps:
+ // 1. Resize w/ aspect-ratio preserving such that the longer side is
+ // the preprocessor longest size
+ // 2. Resize w/out preserving aspect ratio such that both sides are
+ // multiples of image_size (always rounding up)
+ //
+ // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
+ const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio(
+ original_size, params.image_size, params.preproc_image_size);
+
+ llava_uhd::slice_instructions instructions;
+ instructions.overview_size = clip_image_size{params.image_size, params.image_size};
+ instructions.refined_size = refined_size;
+ instructions.grid_size = clip_image_size{
+ static_cast<int>(std::ceil(static_cast<float>(refined_size.width) / params.image_size)),
+ static_cast<int>(std::ceil(static_cast<float>(refined_size.height) / params.image_size)),
+ };
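+ // walk the refined image in steps of image_size; slices at the right/bottom edges may be smaller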
+ for (int y = 0; y < refined_size.height; y += params.image_size) {
+ for (int x = 0; x < refined_size.width; x += params.image_size) {
+ instructions.slices.push_back(llava_uhd::slice_coordinates{
+ /* x */x,
+ /* y */y,
+ /* size */clip_image_size{
+ std::min(params.image_size, refined_size.width - x),
+ std::min(params.image_size, refined_size.height - y)
+ }
+ });
+ }
+ }
+ auto imgs = llava_uhd::slice_image(img, instructions);
+
+ // cast and normalize to f32
+ for (size_t i = 0; i < imgs.size(); ++i) {
+ // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+ clip_image_f32_ptr res(clip_image_f32_init());
+ normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+ res_imgs->entries.push_back(std::move(res));
+ }
+
+ res_imgs->grid_x = instructions.grid_size.width;
+ res_imgs->grid_y = instructions.grid_size.height;
+ return true;
+ } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
|| ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
- || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
|| ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
) {
clip_image_u8 resized_image;
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index cd022c5e24..ff13874cdf 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -76,7 +76,7 @@ enum mtmd_slice_tmpl {
MTMD_SLICE_TMPL_MINICPMV_2_5,
MTMD_SLICE_TMPL_MINICPMV_2_6,
MTMD_SLICE_TMPL_LLAMA4,
- // TODO @ngxson : add support for idefics (SmolVLM)
+ MTMD_SLICE_TMPL_IDEFICS3,
};
const char * mtmd_default_marker() {
@@ -114,19 +114,22 @@ struct mtmd_context {
// for llava-uhd style models, we need special tokens in-between slices
// minicpmv calls them "slices", llama 4 calls them "tiles"
mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE;
- llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image
- llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image
- llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices
- llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices
- llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start
- llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end
- llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices
- llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row
+ std::vector<llama_token> tok_ov_img_start; // overview image
+ std::vector<llama_token> tok_ov_img_end; // overview image
+ std::vector<llama_token> tok_slices_start; // start of all slices
+ std::vector<llama_token> tok_slices_end; // end of all slices
+ std::vector<llama_token> tok_sli_img_start; // single slice start
+ std::vector<llama_token> tok_sli_img_end; // single slice end
+ std::vector<llama_token> tok_sli_img_mid; // between 2 slices
+ std::vector<llama_token> tok_row_end; // end of row
bool tok_row_end_trail = false;
bool ov_img_first = false;
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
+ // string template for slice image delimiters with row/col (idefics3)
+ std::string sli_img_start_tmpl;
+
// for whisper, we pre-calculate the mel filter bank
whisper_preprocessor::whisper_filters w_filters;
@@ -197,13 +200,13 @@ struct mtmd_context {
// minicpmv 2.5 format:
// (overview) (slice) (slice) \n ...
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5;
- tok_ov_img_start = lookup_token("");
- tok_ov_img_end = lookup_token("");
- tok_slices_start = lookup_token("");
- tok_slices_end = lookup_token("");
+ tok_ov_img_start = {lookup_token("")};
+ tok_ov_img_end = {lookup_token("")};
+ tok_slices_start = {lookup_token("")};
+ tok_slices_end = {lookup_token("")};
tok_sli_img_start = tok_ov_img_start;
tok_sli_img_end = tok_ov_img_end;
- tok_row_end = lookup_token("\n");
+ tok_row_end = {lookup_token("\n")};
tok_row_end_trail = false; // no trailing end-of-row token
ov_img_first = true;
@@ -211,11 +214,11 @@ struct mtmd_context {
// minicpmv 2.6 format:
// (overview) (slice) (slice) \n ...
slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6;
- tok_ov_img_start = lookup_token("");
- tok_ov_img_end = lookup_token("");
- tok_sli_img_start = lookup_token("");
- tok_sli_img_end = lookup_token("");
- tok_row_end = lookup_token("\n");
+ tok_ov_img_start = {lookup_token("")};
+ tok_ov_img_end = {lookup_token("")};
+ tok_sli_img_start = {lookup_token("")};
+ tok_sli_img_end = {lookup_token("")};
+ tok_row_end = {lookup_token("\n")};
tok_row_end_trail = false; // no trailing end-of-row token
ov_img_first = true;
@@ -230,9 +233,9 @@ struct mtmd_context {
// <|image|> (overview) <-- overview image is last
// <|image_end|>
slice_tmpl = MTMD_SLICE_TMPL_LLAMA4;
- tok_ov_img_start = lookup_token("<|image|>");
- tok_sli_img_mid = lookup_token("<|tile_x_separator|>");
- tok_row_end = lookup_token("<|tile_y_separator|>");
+ tok_ov_img_start = {lookup_token("<|image|>")};
+ tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")};
+ tok_row_end = {lookup_token("<|tile_y_separator|>")};
tok_row_end_trail = true; // add trailing end-of-row token
ov_img_first = false; // overview image is last
}
@@ -245,8 +248,12 @@ struct mtmd_context {
} else if (proj == PROJECTOR_TYPE_IDEFICS3) {
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
- img_beg = "";
- img_end = "";
+ slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3;
+ tok_ov_img_start = {lookup_token("\n"), lookup_token(""), lookup_token("")};
+ tok_ov_img_end = {lookup_token("")};
+ tok_row_end = {lookup_token("\n")};
+ img_beg = "";
+ sli_img_start_tmpl = "";
} else if (proj == PROJECTOR_TYPE_PIXTRAL) {
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
@@ -504,6 +511,7 @@ struct mtmd_tokenizer {
ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
|| ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
+ || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3
) {
const int n_col = batch_f32.grid_x;
const int n_row = batch_f32.grid_y;
@@ -517,53 +525,45 @@ struct mtmd_tokenizer {
// add overview image (first)
if (ctx->ov_img_first) {
- if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_ov_img_start});
- }
+ add_text(ctx->tok_ov_img_start);
cur.entries.emplace_back(std::move(ov_chunk));
- if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_ov_img_end});
- }
+ add_text(ctx->tok_ov_img_end);
}
// add slices (or tiles)
if (!chunks.empty()) {
GGML_ASSERT((int)chunks.size() == n_row * n_col);
- if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_slices_start});
- }
+ add_text(ctx->tok_slices_start);
for (int y = 0; y < n_row; y++) {
for (int x = 0; x < n_col; x++) {
const bool is_last_in_row = (x == n_col - 1);
- if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_sli_img_start});
+ if (!ctx->tok_sli_img_start.empty()) {
+ add_text(ctx->tok_sli_img_start);
+ } else if (!ctx->sli_img_start_tmpl.empty()) {
+ // if a template is used to precede each slice image
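+ // the template is expected to contain two %d placeholders, filled with the 1-based row and column of the slice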
+ const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1;
+ std::unique_ptr<char[]> buf(new char[sz]);
+ std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1);
+ add_text(std::string(buf.get(), buf.get() + sz - 1), true);
}
cur.entries.emplace_back(std::move(chunks[y * n_col + x]));
- if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_sli_img_end});
- }
- if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_sli_img_mid});
+ add_text(ctx->tok_sli_img_end);
+ if (!is_last_in_row) {
+ add_text(ctx->tok_sli_img_mid);
}
}
- if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_row_end});
+ if ((y != n_row - 1 || ctx->tok_row_end_trail)) {
+ add_text(ctx->tok_row_end);
}
}
- if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_slices_end});
- }
+ add_text(ctx->tok_slices_end);
}
// add overview image (last)
if (!ctx->ov_img_first) {
- if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_ov_img_start});
- }
+ add_text(ctx->tok_ov_img_start);
cur.entries.emplace_back(std::move(ov_chunk));
- if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
- add_text({ctx->tok_ov_img_end});
- }
+ add_text(ctx->tok_ov_img_end);
}
} else {
@@ -780,7 +780,9 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
bool ok = false;
- if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) {
+ if (clip_is_llava(ctx_clip)
+ || clip_is_minicpmv(ctx_clip)
+ || clip_is_glm(ctx_clip)) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
for (size_t i = 0; i < entries.size(); i++) {
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index c64be03630..dbdf7656a6 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -69,6 +69,7 @@ add_test_vision "ggml-org/InternVL2_5-1B-GGUF:Q8_0"
add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0"
+add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp
index dc8e077f34..0885156127 100644
--- a/tools/rpc/rpc-server.cpp
+++ b/tools/rpc/rpc-server.cpp
@@ -22,6 +22,7 @@
#include
#include
#include
+#include <regex>
namespace fs = std::filesystem;
@@ -131,24 +132,24 @@ static std::string fs_get_cache_directory() {
}
struct rpc_server_params {
- std::string host = "127.0.0.1";
- int port = 50052;
- size_t backend_mem = 0;
- bool use_cache = false;
- int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
- std::string device;
+ std::string host = "127.0.0.1";
+ int port = 50052;
+ bool use_cache = false;
+ int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
+ std::vector<std::string> devices;
+ std::vector<size_t> dev_mem;
};
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
fprintf(stderr, "options:\n");
- fprintf(stderr, " -h, --help show this help message and exit\n");
- fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
- fprintf(stderr, " -d DEV, --device device to use\n");
- fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
- fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
- fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
- fprintf(stderr, " -c, --cache enable local file cache\n");
+ fprintf(stderr, " -h, --help show this help message and exit\n");
+ fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
+ fprintf(stderr, " -d, --device comma-separated list of devices\n");
+ fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
+ fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
+ fprintf(stderr, " -m, --mem memory size for each device (in MB)\n");
+ fprintf(stderr, " -c, --cache enable local file cache\n");
fprintf(stderr, "\n");
}
@@ -174,17 +175,17 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) {
return false;
}
- params.device = argv[i];
- if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
- fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
- fprintf(stderr, "available devices:\n");
- for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
- auto * dev = ggml_backend_dev_get(i);
- size_t free, total;
- ggml_backend_dev_memory(dev, &free, &total);
- printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+ const std::regex regex{ R"([,/]+)" };
+ std::string dev_str = argv[i];
+ std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
+ std::sregex_token_iterator end;
+ for ( ; iter != end; ++iter) {
+ try {
+ params.devices.push_back(*iter);
+ } catch (const std::exception & ) {
+ fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
+ return false;
}
- return false;
}
} else if (arg == "-p" || arg == "--port") {
if (++i >= argc) {
@@ -200,7 +201,19 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
if (++i >= argc) {
return false;
}
- params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
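+ // parse a comma- or slash-separated list of per-device memory sizes (in MB)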
+ const std::regex regex{ R"([,/]+)" };
+ std::string mem_str = argv[i];
+ std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
+ std::sregex_token_iterator end;
+ for ( ; iter != end; ++iter) {
+ try {
+ size_t mem = std::stoul(*iter) * 1024 * 1024;
+ params.dev_mem.push_back(mem);
+ } catch (const std::exception & ) {
+ fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
+ return false;
+ }
+ }
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, params);
exit(0);
@@ -213,45 +226,46 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
return true;
}
-static ggml_backend_t create_backend(const rpc_server_params & params) {
- ggml_backend_t backend = nullptr;
+static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
+ std::vector<ggml_backend_dev_t> devices;
+ if (!params.devices.empty()) {
+ for (auto device : params.devices) {
+ ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
+ if (dev) {
+ devices.push_back(dev);
+ } else {
+ fprintf(stderr, "error: unknown device: %s\n", device.c_str());
+ fprintf(stderr, "available devices:\n");
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+ auto * dev = ggml_backend_dev_get(i);
+ size_t free, total;
+ ggml_backend_dev_memory(dev, &free, &total);
+ printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
+ }
+ return {};
+ }
+ }
+ }
- if (!params.device.empty()) {
- ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
+ // Try non-CPU devices first
+ if (devices.empty()) {
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+ if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
+ devices.push_back(dev);
+ }
+ }
+ }
+
+ // If there are no accelerators, fallback to CPU device
+ if (devices.empty()) {
+ ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (dev) {
- backend = ggml_backend_dev_init(dev, nullptr);
- if (!backend) {
- fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
- return nullptr;
- }
+ devices.push_back(dev);
}
}
- if (!backend) {
- backend = ggml_backend_init_best();
- }
-
- if (backend) {
- fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
-
- // set the number of threads
- ggml_backend_dev_t dev = ggml_backend_get_device(backend);
- ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
- if (reg) {
- auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
- if (ggml_backend_set_n_threads_fn) {
- ggml_backend_set_n_threads_fn(backend, params.n_threads);
- }
- }
- }
-
- return backend;
-}
-
-static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
- ggml_backend_dev_t dev = ggml_backend_get_device(backend);
- GGML_ASSERT(dev != nullptr);
- ggml_backend_dev_memory(dev, free_mem, total_mem);
+ return devices;
}
int main(int argc, char * argv[]) {
@@ -273,18 +287,23 @@ int main(int argc, char * argv[]) {
fprintf(stderr, "\n");
}
- ggml_backend_t backend = create_backend(params);
- if (!backend) {
- fprintf(stderr, "Failed to create backend\n");
+ auto devices = get_devices(params);
+ if (devices.empty()) {
+ fprintf(stderr, "No devices found\n");
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
- size_t free_mem, total_mem;
- if (params.backend_mem > 0) {
- free_mem = params.backend_mem;
- total_mem = params.backend_mem;
- } else {
- get_backend_memory(backend, &free_mem, &total_mem);
+ std::vector<size_t> free_mem, total_mem;
+ for (size_t i = 0; i < devices.size(); i++) {
+ if (i < params.dev_mem.size()) {
+ free_mem.push_back(params.dev_mem[i]);
+ total_mem.push_back(params.dev_mem[i]);
+ } else {
+ size_t free, total;
+ ggml_backend_dev_memory(devices[i], &free, &total);
+ free_mem.push_back(free);
+ total_mem.push_back(total);
+ }
}
const char * cache_dir = nullptr;
std::string cache_dir_str;
@@ -309,8 +328,7 @@ int main(int argc, char * argv[]) {
return 1;
}
- start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
-
- ggml_backend_free(backend);
+ start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
+ devices.data(), free_mem.data(), total_mem.data());
return 0;
}
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index 4f18a634ce..2801319c98 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 6062904a8c..a21147613d 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -764,7 +764,7 @@ struct completion_token_output {
}
};
-struct swa_checkpoint {
+struct ctx_checkpoint {
llama_pos pos_min;
llama_pos pos_max;
@@ -1460,7 +1460,7 @@ struct server_slot {
std::vector<completion_token_output> generated_token_probs;
- std::vector<swa_checkpoint> swa_checkpoints;
+ std::vector<ctx_checkpoint> ctx_checkpoints;
bool has_next_token = true;
bool has_new_line = false;
@@ -3541,7 +3541,11 @@ struct server_context {
slot.n_past = 0;
}
- const auto n_swa = llama_model_n_swa(model);
+ // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
+ const auto n_swa = std::max(1, llama_model_n_swa(model));
+
+ // the largest pos_min required for a checkpoint to be useful
+ const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
@@ -3550,66 +3554,62 @@ struct server_context {
GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
}
- const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
if (pos_min > pos_min_thold) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
- // search for a SWA checkpoint
+ // search for a context checkpoint
const auto it = std::find_if(
- slot.swa_checkpoints.rbegin(),
- slot.swa_checkpoints.rend(),
+ slot.ctx_checkpoints.rbegin(),
+ slot.ctx_checkpoints.rend(),
[&](const auto & cur) {
- return cur.pos_min <= pos_min_thold;
+ // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS]
+ return cur.pos_min < pos_min_thold;
}
);
- bool do_reset = it == slot.swa_checkpoints.rend();
+ bool do_reset = it == slot.ctx_checkpoints.rend();
+ //printf("[DEBUG] `do_reset` was set to `%s`\n", do_reset ? "true" : "false");
if (!do_reset) {
- // restore the checkpoint
- const size_t swa_size = it->data.size();
- const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), swa_size, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+ // restore the context checkpoint
+ const size_t ctx_checkpoint_size = it->data.size();
+ const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), ctx_checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- if (n != swa_size) {
- SLT_ERR(slot, "failed to restore SWA checkpoint, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ if (n != ctx_checkpoint_size) {
+ SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
do_reset = true;
+ //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
} else {
- slot.n_past = std::min(slot.n_past, it->pos_max);
-
- SLT_WRN(slot, "SWA checkpoint restore, pos_min = %d, pos_max = %d, size = %.3f MiB\n", it->pos_min, it->pos_max, (float) swa_size / 1024 / 1024);
+ slot.n_past = std::min(slot.n_past, std::max(it->pos_min + 1, it->pos_max));
+ SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) ctx_checkpoint_size / 1024 / 1024);
}
}
if (do_reset) {
- SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+ SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
-
slot.n_past = 0;
- slot.swa_checkpoints.clear();
}
}
}
- if (n_swa > 0) {
- const auto pos_min_thold = std::max(0, slot.n_past - n_swa);
-
+ {
// erase any checkpoints with pos_min > pos_min_thold
- for (int i = (int) slot.swa_checkpoints.size() - 1; i >= 0; i--) {
- const auto & cur = slot.swa_checkpoints[i];
+ for (int i = (int) slot.ctx_checkpoints.size() - 1; i >= 0; i--) {
+ const auto & cur = slot.ctx_checkpoints[i];
if (cur.pos_min > pos_min_thold) {
- slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + i);
-
- SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n", cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
+ SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
+ slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin() + i);
}
}
}
}
+ // [TAG_PROMPT_LOGITS]
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
- SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
-
+ SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, n_prompt_tokens = %d)\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
+ SLT_WRN(slot, "n_past was set to %d\n", slot.n_past);
}
slot.n_prompt_tokens_cache = slot.n_past;
@@ -3623,9 +3623,9 @@ struct server_context {
}
}
- // keep only the common part
+ // truncate any tokens that are beyond n_past for this slot
if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.n_past, -1)) {
- // could not partially delete (likely using a non-Transformer model)
+ SLT_WRN(slot, "failed to truncate tokens beyond n_past = %d\n", slot.n_past);
llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
// there is no common part left
@@ -3633,7 +3633,7 @@ struct server_context {
slot.n_prompt_tokens_cache = 0;
}
- SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+ SLT_INF(slot, "n_past = %d, memory_seq_rm [%d, end)\n", slot.n_past, slot.n_past);
// remove the non-common part from the cache
slot.cache_tokens.keep_first(slot.n_past);
@@ -3854,37 +3854,38 @@ struct server_context {
// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;
- // make a checkpoint with the SWA memory
- // checkpoints are needed only if we are not using "--swa-full"
- if (llama_model_n_swa(model) > 0 && !params_base.swa_full && params_base.n_swa_checkpoints > 0) {
- if (slot.swa_checkpoints.size() >= (size_t) params_base.n_swa_checkpoints) {
- {
- const auto & cur = slot.swa_checkpoints.back();
+ // make a checkpoint of the parts of the memory that cannot be rolled back.
+ // checkpoints are created only if:
+ // - the model uses SWA and we are not using `swa_full`
+ // - the model architecture is marked as recurrent or hybrid
+ //
+ // TODO: try to make this conditional on the context or the memory module, instead of the model type
+ const bool do_checkpoint =
+ (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) ||
+ (llama_model_n_swa(model) > 0 && !params_base.swa_full);
- SLT_WRN(slot, "SWA checkpoint erase, pos_min = %d, pos_max = %d, size = %.3f MiB\n",
- cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
- }
+ if (do_checkpoint && params_base.n_ctx_checkpoints > 0) {
+ while (slot.ctx_checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
+ // make room for the new checkpoint, if needed
+ const auto & cur = slot.ctx_checkpoints.front();
+ SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
+ cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
- slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
+ slot.ctx_checkpoints.erase(slot.ctx_checkpoints.begin());
}
- const size_t swa_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_SWA_ONLY);
+ const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
- auto & cur = slot.swa_checkpoints.emplace_back(swa_checkpoint{
+ auto & cur = slot.ctx_checkpoints.emplace_back(ctx_checkpoint{
/*.pos_min = */ llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id),
/*.pos_max = */ llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id),
- /*.data = */ std::vector<uint8_t>(swa_size),
+ /*.data = */ std::vector