Merge branch 'ggml-org:master' into quantize

Commit: 9b857e3984
@@ -4,7 +4,7 @@ ARG UBUNTU_VERSION=24.04
 ARG ROCM_VERSION=6.4
 ARG AMDGPU_VERSION=6.4
 
-# Target the CUDA build image
+# Target the ROCm build image
 ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
 
 ### Build image
@@ -15,12 +15,12 @@ FROM ${BASE_ROCM_DEV_CONTAINER} AS build
 # This is mostly tied to rocBLAS supported archs.
 # gfx803, gfx900, gfx1032, gfx1101, gfx1102: not officially supported
 # gfx906 is deprecated
-#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.4/reference/system-requirements.html
+#check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html
 
-ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102'
+ARG ROCM_DOCKER_ARCH='gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201'
 #ARG ROCM_DOCKER_ARCH=gfx1100
 
-# Set nvcc architectures
+# Set ROCm architectures
 ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH}
 # Enable ROCm
 # ENV CC=/opt/rocm/llvm/bin/clang
@@ -88,6 +88,7 @@ jobs:
             -DGGML_METAL_SHADER_DEBUG=ON \
             -DGGML_RPC=ON
           cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
+          leaks -atExit -- ./build/bin/test-thread-safety -hf ggml-org/gemma-3-270m-qat-GGUF -ngl 99 -p "$(printf 'hello %.0s' {1..128})" -n 16 -c 512 -ub 32 -np 2 -t 2 -lv 1
 
       - name: Test
         id: cmake_test
@@ -126,7 +127,8 @@ jobs:
             -DCMAKE_BUILD_RPATH="@loader_path" \
             -DLLAMA_FATAL_WARNINGS=ON \
             -DGGML_METAL=OFF \
-            -DGGML_RPC=ON
+            -DGGML_RPC=ON \
+            -DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
           cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
 
       - name: Test
@@ -1050,9 +1052,13 @@ jobs:
         run: examples/sycl/win-build-sycl.bat
 
   windows-latest-cmake-hip:
-    if: ${{ github.event.inputs.create_release != 'true' }}
     runs-on: windows-2022
 
+    env:
+      # The ROCm version must correspond to the version used in the HIP SDK.
+      ROCM_VERSION: "6.4.2"
+      HIPSDK_INSTALLER_VERSION: "25.Q3"
+
     steps:
       - name: Clone
         id: checkout
@@ -1061,16 +1067,14 @@ jobs:
       - name: Clone rocWMMA repository
        id: clone_rocwmma
        run: |
-          git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
+          git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
 
       - name: Cache ROCm Installation
         id: cache-rocm
         uses: actions/cache@v4
         with:
           path: C:\Program Files\AMD\ROCm
-          key: rocm-6.1-${{ runner.os }}-v1
-          restore-keys: |
-            rocm-6.1-${{ runner.os }}-
+          key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
 
       - name: Install ROCm
         if: steps.cache-rocm.outputs.cache-hit != 'true'
@@ -1078,7 +1082,7 @@ jobs:
         run: |
           $ErrorActionPreference = "Stop"
           write-host "Downloading AMD HIP SDK Installer"
-          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
           write-host "Installing AMD HIP SDK"
           $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
           $completed = $proc.WaitForExit(600000)
@@ -108,7 +108,8 @@ jobs:
             -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
             -DLLAMA_FATAL_WARNINGS=ON \
             -DGGML_METAL=OFF \
-            -DGGML_RPC=ON
+            -DGGML_RPC=ON \
+            -DCMAKE_OSX_DEPLOYMENT_TARGET=13.3
           cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
 
       - name: Determine tag name
@@ -528,11 +529,16 @@ jobs:
   windows-hip:
     runs-on: windows-2022
 
+    env:
+      # The ROCm version must correspond to the version used in the HIP SDK.
+      ROCM_VERSION: "6.4.2"
+      HIPSDK_INSTALLER_VERSION: "25.Q3"
+
     strategy:
       matrix:
         include:
          - name: "radeon"
-            gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
+            gpu_targets: "gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032"
 
     steps:
       - name: Clone
@@ -542,21 +548,19 @@ jobs:
       - name: Clone rocWMMA repository
         id: clone_rocwmma
         run: |
-          git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
+          git clone https://github.com/rocm/rocwmma --branch rocm-${{ env.ROCM_VERSION }} --depth 1
 
       - name: Cache ROCm Installation
         id: cache-rocm
         uses: actions/cache@v4
         with:
           path: C:\Program Files\AMD\ROCm
-          key: rocm-6.1-${{ runner.os }}-v1
-          restore-keys: |
-            rocm-6.1-${{ runner.os }}-
+          key: rocm-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ runner.os }}
 
       - name: ccache
         uses: ggml-org/ccache-action@v1.2.16
         with:
-          key: windows-latest-cmake-hip-${{ matrix.name }}-x64
+          key: windows-latest-cmake-hip-${{ env.HIPSDK_INSTALLER_VERSION }}-${{ matrix.name }}-x64
           evict-old-files: 1d
 
       - name: Install ROCm
@@ -565,7 +569,7 @@ jobs:
         run: |
           $ErrorActionPreference = "Stop"
           write-host "Downloading AMD HIP SDK Installer"
-          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
+          Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
           write-host "Installing AMD HIP SDK"
           $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
           $completed = $proc.WaitForExit(600000)
@@ -610,9 +614,12 @@ jobs:
             -DLLAMA_CURL=OFF
           cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS}
           md "build\bin\rocblas\library\"
+          md "build\bin\hipblaslt\library"
           cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\"
+          cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\"
           cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\"
           cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\"
+          cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\"
 
       - name: Pack artifacts
         id: pack_artifacts
ci/run.sh (19 changes)

@@ -270,7 +270,9 @@ function gg_run_ctest_with_model_debug {
     local model; model=$(gg_get_model)
     cd build-ci-debug
     set -e
+
     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log
+
     set +e
     cd ..
 }
@@ -281,7 +283,15 @@ function gg_run_ctest_with_model_release {
     local model; model=$(gg_get_model)
     cd build-ci-release
     set -e
+
     (LLAMACPP_TEST_MODELFILE="$model" time ctest --output-on-failure -L model) 2>&1 | tee -a $OUT/${ci}-ctest.log
+
+    # test memory leaks
+    #if [[ ! -z ${GG_BUILD_METAL} ]]; then
+    #    # TODO: this hangs for some reason ...
+    #    (time leaks -quiet -atExit -- ./bin/test-thread-safety -m $model --parallel 2 -t 2 -p "hello") 2>&1 | tee -a $OUT/${ci}-leaks.log
+    #fi
+
     set +e
     cd ..
 }
@@ -860,10 +870,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
 fi
 
 ret=0
-if [ -z ${GG_BUILD_SYCL} ]; then
-    # SYCL build breaks with debug build flags
-    test $ret -eq 0 && gg_run ctest_debug
-fi
+test $ret -eq 0 && gg_run ctest_debug
 test $ret -eq 0 && gg_run ctest_release
 
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
@@ -871,9 +878,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
     test $ret -eq 0 && gg_run rerank_tiny
 
     if [ -z ${GG_BUILD_CLOUD} ] || [ ${GG_BUILD_EXTRA_TESTS_0} ]; then
-        if [ -z ${GG_BUILD_SYCL} ]; then
-            test $ret -eq 0 && gg_run test_scripts_debug
-        fi
+        test $ret -eq 0 && gg_run test_scripts_debug
         test $ret -eq 0 && gg_run test_scripts_release
     fi
@@ -884,9 +889,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
         test $ret -eq 0 && gg_run pythia_2_8b
         #test $ret -eq 0 && gg_run open_llama_7b_v2
     fi
-    if [ -z ${GG_BUILD_SYCL} ]; then
-        test $ret -eq 0 && gg_run ctest_with_model_debug
-    fi
+    test $ret -eq 0 && gg_run ctest_with_model_debug
     test $ret -eq 0 && gg_run ctest_with_model_release
 fi
common/arg.cpp (152 changes)

@@ -745,6 +745,124 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
 
 #endif // LLAMA_USE_CURL
 
+//
+// Docker registry functions
+//
+
+static std::string common_docker_get_token(const std::string & repo) {
+    std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
+
+    common_remote_params params;
+    auto res = common_remote_get_content(url, params);
+
+    if (res.first != 200) {
+        throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
+    }
+
+    std::string response_str(res.second.begin(), res.second.end());
+    nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
+
+    if (!response.contains("token")) {
+        throw std::runtime_error("Docker registry token response missing 'token' field");
+    }
+
+    return response["token"].get<std::string>();
+}
+
+static std::string common_docker_resolve_model(const std::string & docker) {
+    // Parse ai/smollm2:135M-Q4_K_M
+    size_t colon_pos = docker.find(':');
+    std::string repo, tag;
+    if (colon_pos != std::string::npos) {
+        repo = docker.substr(0, colon_pos);
+        tag  = docker.substr(colon_pos + 1);
+    } else {
+        repo = docker;
+        tag  = "latest";
+    }
+
+    // ai/ is the default
+    size_t slash_pos = docker.find('/');
+    if (slash_pos == std::string::npos) {
+        repo.insert(0, "ai/");
+    }
+
+    LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
+    try {
+        // --- helper: digest validation ---
+        auto validate_oci_digest = [](const std::string & digest) -> std::string {
+            // Expected: algo:hex ; start with sha256 (64 hex chars)
+            // You can extend this map if supporting other algorithms in future.
+            static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
+            std::smatch m;
+            if (!std::regex_match(digest, m, re)) {
+                throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
+            }
+            // normalize hex to lowercase
+            std::string normalized = digest;
+            std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
+                return std::tolower(c);
+            });
+            return normalized;
+        };
+
+        std::string token = common_docker_get_token(repo); // Get authentication token
+
+        // Get manifest
+        const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
+        std::string manifest_url = url_prefix + "/manifests/" + tag;
+        common_remote_params manifest_params;
+        manifest_params.headers.push_back("Authorization: Bearer " + token);
+        manifest_params.headers.push_back(
+            "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
+        auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
+        if (manifest_res.first != 200) {
+            throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
+        }
+
+        std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
+        nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
+        std::string gguf_digest; // Find the GGUF layer
+        if (manifest.contains("layers")) {
+            for (const auto & layer : manifest["layers"]) {
+                if (layer.contains("mediaType")) {
+                    std::string media_type = layer["mediaType"].get<std::string>();
+                    if (media_type == "application/vnd.docker.ai.gguf.v3" ||
+                        media_type.find("gguf") != std::string::npos) {
+                        gguf_digest = layer["digest"].get<std::string>();
+                        break;
+                    }
+                }
+            }
+        }
+
+        if (gguf_digest.empty()) {
+            throw std::runtime_error("No GGUF layer found in Docker manifest");
+        }
+
+        // Validate & normalize digest
+        gguf_digest = validate_oci_digest(gguf_digest);
+        LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
+
+        // Prepare local filename
+        std::string model_filename = repo;
+        std::replace(model_filename.begin(), model_filename.end(), '/', '_');
+        model_filename += "_" + tag + ".gguf";
+        std::string local_path = fs_get_cache_file(model_filename);
+
+        const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
+        if (!common_download_file_single(blob_url, local_path, token, false)) {
+            throw std::runtime_error("Failed to download Docker Model");
+        }
+
+        LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
+        return local_path;
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
+        throw;
+    }
+}
+
 //
 // utils
 //
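Note: the resolution flow above composes three steps: an anonymous pull token from auth.docker.io, a manifest lookup on registry-1.docker.io, then a blob download of the first layer whose mediaType mentions GGUF. A minimal caller would look like the sketch below; this is illustrative only (the helper is file-static in common/arg.cpp, so real callers go through the --docker-repo option added further down).

#include <cstdio>
#include <exception>
#include <string>

// added above in common/arg.cpp; redeclared here purely for illustration
std::string common_docker_resolve_model(const std::string & docker);

int main() {
    try {
        // "gemma3" expands to "ai/gemma3:latest": the "ai/" namespace and
        // the ":latest" tag are filled in when omitted, per the parsing above
        const std::string path = common_docker_resolve_model("gemma3");
        std::printf("model cached at %s\n", path.c_str());
    } catch (const std::exception & e) {
        std::fprintf(stderr, "resolve failed: %s\n", e.what());
        return 1;
    }
    return 0;
}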
@@ -795,7 +913,9 @@ static handle_model_result common_params_handle_model(
     handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
-        if (!model.hf_repo.empty()) {
+        if (!model.docker_repo.empty()) { // Handle Docker URLs by resolving them to local paths
+            model.path = common_docker_resolve_model(model.docker_repo);
+        } else if (!model.hf_repo.empty()) {
             // short-hand to avoid specifying --hf-file -> default it to --model
             if (model.hf_file.empty()) {
                 if (model.path.empty()) {
@@ -1184,7 +1304,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     } else {
         for (const auto & device : dev_names) {
             auto * dev = ggml_backend_dev_by_name(device.c_str());
-            if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
+            if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
                 throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
             }
             devices.push_back(dev);
@@ -1194,7 +1314,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     return devices;
 }
 
-static void add_rpc_devices(std::string servers) {
+static void add_rpc_devices(const std::string & servers) {
     auto rpc_servers = string_split<std::string>(servers, ',');
     if (rpc_servers.empty()) {
         throw std::invalid_argument("no RPC servers specified");
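Note: the validity check flips from "must be a GPU" to "anything but the CPU device", which is what admits the new integrated-GPU device type introduced elsewhere in this commit. As a standalone predicate (a sketch, restating the condition above):

// Selectable iff the name resolves and the device is not the CPU backend.
static bool is_selectable(ggml_backend_dev_t dev) {
    return dev != NULL && ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU;
}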
@@ -2396,24 +2516,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--list-devices"},
         "print list of available devices and exit",
         [](common_params &) {
-            std::vector<ggml_backend_dev_t> rpc_devices;
-            std::vector<ggml_backend_dev_t> all_devices;
+            std::vector<ggml_backend_dev_t> devices;
             for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
                 auto * dev = ggml_backend_dev_get(i);
-                if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
-                    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
-                    if (ggml_backend_reg_name(reg) == std::string("RPC")) {
-                        rpc_devices.push_back(dev);
-                    } else {
-                        all_devices.push_back(dev);
-                    }
+                if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
+                    devices.push_back(dev);
                 }
             }
-            // insert RPC devices in front
-            all_devices.insert(all_devices.begin(), rpc_devices.begin(), rpc_devices.end());
             printf("Available devices:\n");
-            for (size_t i = 0; i < all_devices.size(); ++i) {
-                auto * dev = all_devices[i];
+            for (auto * dev : devices) {
                 size_t free, total;
                 ggml_backend_dev_memory(dev, &free, &total);
                 printf("  %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
@@ -2636,6 +2747,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.model.url = value;
         }
     ).set_env("LLAMA_ARG_MODEL_URL"));
+    add_opt(common_arg(
+        {"-dr", "--docker-repo"}, "[<repo>/]<model>[:quant]",
+        "Docker Hub model repository. repo is optional, defaults to ai/. quant is optional, defaults to :latest.\n"
+        "example: gemma3\n"
+        "(default: unused)",
+        [](common_params & params, const std::string & value) {
+            params.model.docker_repo = value;
+        }
+    ).set_env("LLAMA_ARG_DOCKER_REPO"));
     add_opt(common_arg(
         {"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
         "Hugging Face model repository; quant is optional, case-insensitive, defaults to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
@@ -197,6 +197,7 @@ struct common_params_model {
     std::string url      = ""; // model url to download // NOLINT
     std::string hf_repo  = ""; // HF repo               // NOLINT
     std::string hf_file  = ""; // HF file               // NOLINT
+    std::string docker_repo = ""; // Docker repo        // NOLINT
 };
 
 struct common_params_speculative {
@@ -287,9 +288,9 @@ struct common_params {
     float rope_freq_base   = 0.0f;  // RoPE base frequency
     float rope_freq_scale  = 0.0f;  // RoPE frequency scaling factor
     float yarn_ext_factor  = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = 1.0f;  // YaRN magnitude scaling factor
-    float yarn_beta_fast   = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow   = 1.0f;  // YaRN high correction dim
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float yarn_beta_fast   = -1.0f; // YaRN low correction dim
+    float yarn_beta_slow   = -1.0f; // YaRN high correction dim
     int32_t yarn_orig_ctx  = 0;     // YaRN original context length
 
     // offload params
@@ -452,7 +453,7 @@ struct common_params {
     std::string slot_save_path;
 
-    float slot_prompt_similarity = 0.5f;
+    float slot_prompt_similarity = 0.1f;
 
     // batched-bench params
     bool is_pp_shared = false;
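Note: the YaRN defaults above move from concrete values to -1.0f, matching yarn_ext_factor. A plausible reading (an assumption on my part, not stated in this diff) is that negative now means "unset", letting the loader substitute per-model values such as the beta_fast/beta_slow hyperparameters the Grok converter writes below. In sketch form:

// Hedged sketch: resolve a YaRN parameter, treating negative as "unset".
// model_default is hypothetical here; the real lookup lives in the loader.
static float resolve_yarn_param(float cli_value, float model_default) {
    return cli_value >= 0.0f ? cli_value : model_default;
}

// e.g. resolve_yarn_param(params.yarn_beta_fast, /*model_default=*/32.0f)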
@@ -735,6 +735,9 @@ class TextModel(ModelBase):
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -2682,12 +2685,20 @@ class BitnetModel(TextModel):
             yield (new_name, data_torch)
 
 
-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -2695,11 +2706,46 @@ class GrokModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]
-
             assert bid is not None
@@ -2707,32 +2753,41 @@ class GrokModel(TextModel):
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""
 
+            for bid in range(self.block_count):
                 if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
-
                     # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
                         datas: list[Tensor] = []
 
                         for xid in range(n_experts):
-                    ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                    datas.append(self._experts[bid][ename])
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
                             del self._experts[bid][ename]
 
                         data_torch = torch.stack(datas, dim=0)
 
-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
 
                         new_name = self.map_tensor_name(merged_name)
 
-                tensors.append((new_name, data_torch))
-            return tensors
-        else:
-            return []
+                        yield (new_name, data_torch)
 
-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors
 
 
 @ModelBase.register("DbrxForCausalLM")
@@ -158,6 +158,7 @@ pre_computed_hashes = [
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
+    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]
@@ -318,3 +318,7 @@ Operators are executed using ACL graph execution, rather than in op-by-op (eager
 ### GGML_CANN_GRAPH_CACHE_CAPACITY
 
 Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted.
+
+### GGML_CANN_PREFILL_USE_GRAPH
+
+Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.
@@ -241,8 +241,8 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
 |            | VX/VXE/VXE2 | zDNN | Spyre |
 |------------|-------------|------|-------|
 | FP32       | ✅          | ✅   | ❓    |
-| FP16       | ✅          | ❓   | ❓    |
-| BF16       | 🚫          | ❓   | ❓    |
+| FP16       | ✅          | ✅   | ❓    |
+| BF16       | 🚫          | ✅   | ❓    |
 | Q4_0       | ✅          | ❓   | ❓    |
 | Q4_1       | ✅          | ❓   | ❓    |
 | MXFP4      | 🚫          | ❓   | ❓    |
@@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
 - 🚫 - acceleration unavailable, will still run using scalar implementation
 - ❓ - acceleration unknown, please contribute if you can test it yourself
 
-Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025.
+Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025.
@@ -18,6 +18,7 @@ Legend:
 | ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 | ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
+| ADD_ID | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
 | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
@@ -26,6 +27,7 @@ Legend:
 | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
 | CONV_2D | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ |
 | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
+| CONV_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
@@ -49,9 +51,11 @@ Legend:
 | GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
 | GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 | GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
+| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
+| IM2COL_3D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 | LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -61,7 +65,9 @@ Legend:
 | MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
 | NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 | NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
+| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
+| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
 | PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
@@ -98,6 +104,7 @@ Legend:
 | SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 | SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
+| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
 | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
docs/ops/zDNN.csv (11114 changes): file diff suppressed because it is too large.
@@ -132,6 +132,8 @@ extern "C" {
         GGML_BACKEND_DEVICE_TYPE_CPU,
         // GPU device using dedicated memory
         GGML_BACKEND_DEVICE_TYPE_GPU,
+        // integrated GPU device using host memory
+        GGML_BACKEND_DEVICE_TYPE_IGPU,
         // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
         GGML_BACKEND_DEVICE_TYPE_ACCEL
     };
@@ -150,11 +152,21 @@ extern "C" {
 
     // all the device properties
     struct ggml_backend_dev_props {
+        // device name
         const char * name;
+        // device description
         const char * description;
+        // device free memory in bytes
         size_t memory_free;
+        // device total memory in bytes
         size_t memory_total;
+        // device type
         enum ggml_backend_dev_type type;
+        // device id
+        // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
+        // if the id is unknown, this should be NULL
+        const char * device_id;
+        // device capabilities
         struct ggml_backend_dev_caps caps;
     };
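Note: a short consumer of the extended struct, assuming the existing accessor ggml_backend_dev_get_props (a sketch, not part of this diff):

#include <cstdio>
#include "ggml-backend.h"

// Print every registered device, including the new device_id field.
static void print_devices(void) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);

        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);

        // device_id may be NULL when the bus id is unknown (see comment above)
        std::printf("%s (%s): %zu/%zu MiB free, id=%s\n",
            props.name, props.description,
            props.memory_free  / 1024 / 1024,
            props.memory_total / 1024 / 1024,
            props.device_id ? props.device_id : "unknown");
    }
}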
@@ -7,8 +7,6 @@
 extern "C" {
 #endif
 
-GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void);
-
 GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
 
 #ifdef __cplusplus
@@ -8,7 +8,7 @@
 extern "C" {
 #endif
 
-#define GGML_BACKEND_API_VERSION 1
+#define GGML_BACKEND_API_VERSION 2
 
 //
 // Backend buffer type
@@ -400,9 +400,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const
 
 ggml_backend_t ggml_backend_init_best(void) {
     ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
-    if (!dev) {
-        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    }
+    dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);
+    dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     if (!dev) {
         return nullptr;
     }
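Note: the selection order is now discrete GPU, then the new integrated-GPU type, then CPU. Callers are unchanged; a minimal sketch:

#include "ggml-backend.h"

int main() {
    // tries GPU, then IGPU, then CPU; NULL only if no device is registered
    ggml_backend_t backend = ggml_backend_init_best();
    if (backend == NULL) {
        return 1;
    }
    // ... use the backend ...
    ggml_backend_free(backend);
    return 0;
}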
@@ -2360,6 +2360,21 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     bool use_cann_graph = true;
     bool cann_graph_update_required = false;
 
+    static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
+    if (!prefill_use_graph) {
+        // Do not use acl_graph for prefill.
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            ggml_tensor * node = cgraph->nodes[i];
+            // TODO: Optimize here. Currently, we can only
+            // get seq_len by FA's input.
+            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+                // Q -> src[0], shape: [B, S, N, D]
+                use_cann_graph = (node->src[0]->ne[1] == 1);
+                break;
+            }
+        }
+    }
+
     if (!cann_ctx->acl_graph_mode) {
         use_cann_graph = false;
     }
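Note: the gate is read once (static local) and disables graph capture only for prefill-shaped batches, detected via the flash-attention node's Q sequence length. Reduced to a standalone decision (a hedged restatement of the logic above, nothing more):

// seq_len == 1 -> decode step  -> graph capture allowed
// seq_len  > 1 -> prefill step -> graph capture skipped unless the env flag is set
static bool allow_graph_capture(bool prefill_use_graph, long seq_len) {
    return prefill_use_graph || seq_len == 1;
}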
@@ -224,8 +224,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
                 string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
                 if (NOT ${feature_pos} EQUAL -1)
-                    message(STATUS "ARM feature ${feature} enabled")
+                    # Special handling for MATMUL_INT8 when machine doesn't support i8mm
+                    if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
+                        message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
+                        list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
+                    else()
+                        message(STATUS "ARM feature ${feature} enabled")
+                    endif()
                 endif()
             endforeach()
         endif()
     endif()
@@ -515,9 +515,6 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
             op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
-            if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
-                return false;
-            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
@@ -555,7 +555,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
 }
 
 static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
-#if defined(GGML_USE_HIP) && defined(GCN)
+#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
     asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
 #else
 #ifdef FAST_FP16_AVAILABLE
@@ -567,7 +567,21 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
     acc += tmpv.x * tmpu.x;
     acc += tmpv.y * tmpu.y;
 #endif // FAST_FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(GCN)
+#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+template <int nbytes>
+static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
+    if constexpr (nbytes == 4) {
+        *(int *) dst = *(const int *) src;
+    } else if constexpr (nbytes == 8) {
+        *(int2 *) dst = *(const int2 *) src;
+    } else if constexpr (nbytes == 16) {
+        *(int4 *) dst = *(const int4 *) src;
+    } else {
+        static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+    }
 }
 
 static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
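Note: the if constexpr chain resolves the copy width at compile time, and the dependent static_assert(nbytes == 0 && nbytes == -1, ...) fires only for instantiations that reach the final branch. A host-side C++ analog of the same dispatch pattern (a sketch; the device version uses int/int2/int4 stores so the compiler can emit single wide transactions):

#include <cstring>

// Compile-time dispatch on copy width; unsupported widths fail to compile.
template <int nbytes>
static void memcpy_1(void * dst, const void * src) {
    if constexpr (nbytes == 4 || nbytes == 8 || nbytes == 16) {
        std::memcpy(dst, src, nbytes); // fixed-size copy the optimizer can widen
    } else {
        // dependent false: only triggers when this branch is instantiated
        static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
    }
}

int main() {
    int src[4] = {1, 2, 3, 4};
    int dst[4] = {};
    memcpy_1<16>(dst, src); // one 16-byte transfer
    return dst[3] - 4;      // 0 on success
}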
@@ -8,11 +8,14 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
     if (GGML_CUDA_CC_IS_AMD(cc)) {
         switch (D) {
             case 64:
-                return ncols <= 16 ? 32 : 64;
-            case 128:
-                return ncols <= 16 ? 64 : warp_size;
-            case 256:
                 return 64;
+            case 128:
+            case 256:
+                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+                    return ncols <= 16 ? 64 : 32;
+                } else {
+                    return 64;
+                }
             default:
                 GGML_ABORT("fatal error");
                 return -1;
@@ -41,17 +44,26 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
         GGML_ABORT("fatal error");
         return -1;
     }
+    GGML_UNUSED(warp_size);
 }
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
     switch (D) {
         case 64:
-            return ncols <= 16 ? 32 : 64;
-        case 128:
-            return ncols <= 16 ? 64 : warp_size;
-        case 256:
             return 64;
+        case 128:
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
+        case 256:
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 32;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
@@ -88,9 +100,17 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
         case 64:
             return 64;
         case 128:
-            return ncols <= 16 ? 2*warp_size : 128;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return 64;
+#endif // defined(GCN) || defined(CDNA)
         case 256:
-            return ncols <= 16 ? 128 : 2*warp_size;
+#if defined(GCN) || defined(CDNA)
+            return ncols <= 16 ? 64 : 128;
+#else
+            return ncols <= 16 ? 64 : 256;
+#endif // defined(GCN) || defined(CDNA)
         default:
             return -1;
     }
@@ -196,14 +216,21 @@ static __global__ void flash_attn_tile(
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
+#if defined(GGML_USE_HIP)
+    constexpr int cpy_nb = 16;
+#else
+    constexpr int cpy_nb = 8;
+#endif // defined(GGML_USE_HIP)
+    constexpr int cpy_ne = cpy_nb / 4;
+
     __shared__ float KQ[ncols][kq_stride];
 #ifdef FAST_FP16_AVAILABLE
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
     half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
+    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
     float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
     float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
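Note: the shared-memory padding grows from one element to cpy_ne elements per row so that the vectorized cpy_nb-byte accesses introduced below stay within a row while still breaking the power-of-two stride that causes bank conflicts. Illustratively (values are one example instantiation, not the only one):

// Row strides of the padded shared-memory tiles, as in the kernel above.
constexpr int cpy_nb = 16;          // per-access copy width in bytes (HIP path)
constexpr int cpy_ne = cpy_nb / 4;  // = 4 elements of 4 bytes each

constexpr int kq_nbatch = 128;                      // illustrative batch size
constexpr int row_stride_h2 = kq_nbatch/2 + cpy_ne; // half2 rows
constexpr int row_stride_f  = kq_nbatch   + cpy_ne; // float rows

// index of element k in row i, matching KV_tmp_f[i*(kq_nbatch + cpy_ne) + k]
constexpr int idx(int i, int k) { return i * row_stride_f + k; }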
@@ -256,11 +283,11 @@ static __global__ void flash_attn_tile(
             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
                 const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
 #else
                 const float2 tmp_f2 = __half22float2(tmp_h2);
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 +             threadIdx.x] = tmp_f2.x;
-                KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 +             threadIdx.x] = tmp_f2.x;
+                KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
 #endif // FAST_FP16_AVAILABLE
             }
         }
@@ -269,14 +296,14 @@ static __global__ void flash_attn_tile(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
-            half2 K_k[kq_stride/warp_size];
-            half2 Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
+            half2 K_k[kq_stride/warp_size][cpy_ne];
+            half2 Q_k[ncols/nwarps][cpy_ne];
 #else
 #pragma unroll
-        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
-            float K_k[kq_stride/warp_size];
-            float Q_k[ncols/nwarps];
+        for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
+            float K_k[kq_stride/warp_size][cpy_ne];
+            float Q_k[ncols/nwarps][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -284,9 +311,9 @@ static __global__ void flash_attn_tile(
                 const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 #pragma unroll
@@ -294,9 +321,9 @@ static __global__ void flash_attn_tile(
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
+                ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
             }
 
@@ -304,7 +331,10 @@ static __global__ void flash_attn_tile(
         for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
+#pragma unroll
+                for (int k = 0; k < cpy_ne; ++k) {
+                    ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                }
             }
         }
     }
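`ggml_cuda_memcpy_1<nbytes>` (defined elsewhere in the CUDA backend) copies a compile-time byte count, which the compiler can lower to a single wide register move; the `cpy_ne`-wide register tiles it fills then feed the FMA loop above. A hedged stand-alone sketch of such a size-templated copy, as an illustration rather than the backend's exact implementation:

```cpp
#include <cstring>

// Size-templated copy: because nbytes is a compile-time constant, the memcpy
// is lowered to one (or a few) wide register moves instead of a byte loop.
template <int nbytes, typename T>
static inline void memcpy_fixed(T * dst, const T * src) {
    static_assert(nbytes % sizeof(T) == 0, "copy size must cover whole elements");
    std::memcpy(dst, src, nbytes);
}

int main() {
    float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float dst[4] = {};
    memcpy_fixed<16>(dst, src); // 16 bytes == 4 floats, one 128-bit move on most targets
    return dst[3] == 4.0f ? 0 : 1;
}
```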
@@ -345,14 +375,54 @@ static __global__ void flash_attn_tile(
         kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
 
         float kqsum_add = 0.0f;
+        if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
 #pragma unroll
-        for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-            const int i = i0 + threadIdx.x;
+            for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
+                const int i = i0 + 4*threadIdx.x;
 
-            const float diff = KQ[j][i] - kqmax[j0/nwarps];
-            const float val = expf(diff);
-            kqsum_add += val;
-            KQ[j][i] = val;
-        }
+                float4 val = *(const float4 *) &KQ[j][i];
+                val.x = expf(val.x - kqmax[j0/nwarps]);
+                val.y = expf(val.y - kqmax[j0/nwarps]);
+                val.z = expf(val.z - kqmax[j0/nwarps]);
+                val.w = expf(val.w - kqmax[j0/nwarps]);
+                kqsum_add += val.x + val.y + val.z + val.w;
+
+#ifdef FAST_FP16_AVAILABLE
+                const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
+                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+            }
+        } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+#pragma unroll
+            for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
+                const int i = i0 + 2*threadIdx.x;
+
+                float2 val = *(const float2 *) &KQ[j][i];
+                val.x = expf(val.x - kqmax[j0/nwarps]);
+                val.y = expf(val.y - kqmax[j0/nwarps]);
+                kqsum_add += val.x + val.y;
+#ifdef FAST_FP16_AVAILABLE
+                const half2 tmp = make_half2(val.x, val.y);
+                ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+#else
+                ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+#endif // FAST_FP16_AVAILABLE
+            }
+        } else {
+            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
+
+                const float diff = KQ[j][i] - kqmax[j0/nwarps];
+                const float val  = expf(diff);
+                kqsum_add += val;
+#ifdef FAST_FP16_AVAILABLE
+                ((half *) KQ[j])[i] = val;
+#else
+                KQ[j][i] = val;
+#endif // FAST_FP16_AVAILABLE
+            }
+        }
         kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
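The hunk above exponentiates four (or two) softmax values per thread per iteration by loading them as one `float4`/`float2`, instead of one scalar `expf` per loop trip. A small C++ toy model of the 4-wide path, with made-up data:

```cpp
#include <cmath>
#include <cstdio>

// Toy model of the 4-wide softmax-exponentiation path: each "thread" handles
// 4 consecutive values per iteration, quartering the loop trip count.
int main() {
    float kq[8] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
    const float kqmax = 7.0f; // row maximum, subtracted for numerical stability

    float kqsum_add = 0.0f;
    for (int i = 0; i < 8; i += 4) {   // one float4-style step
        for (int k = 0; k < 4; ++k) {  // unrolled x/y/z/w lanes
            kq[i + k] = std::exp(kq[i + k] - kqmax);
            kqsum_add += kq[i + k];
        }
    }
    printf("kqsum_add = %f\n", kqsum_add);
    return 0;
}
```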
@@ -419,8 +489,7 @@ static __global__ void flash_attn_tile(
             const int j = j0 + threadIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
-                const float tmp = KQ[j][k0 + k1];
-                KQ_k[j0/nwarps] = make_half2(tmp, tmp);
+                KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
 #else
                 KQ_k[j0/nwarps] = KQ[j][k0 + k1];
 #endif // FAST_FP16_AVAILABLE
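Note the trick the two hunks above rely on: in the FP16 path the `float`-typed shared `KQ` array is reused as raw storage for twice as many 16-bit values, written via `((half *) KQ[j])[i]` and read back with `__half2half2`. A self-contained C++ sketch of the size relationship, using `uint16_t` as a stand-in for `half`:

```cpp
#include <cstdio>
#include <cstring>
#include <cstdint>

// Toy model of the KQ buffer reuse: a float-typed buffer provides raw storage
// for twice as many 16-bit values, halving shared-memory traffic per element.
int main() {
    float storage[4] = {};                          // 16 bytes of raw storage
    uint16_t halves[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // stand-ins for __half
    static_assert(sizeof(storage) == sizeof(halves), "same footprint");
    std::memcpy(storage, halves, sizeof(halves));   // 8 halves fit in 4 floats
    uint16_t back[8];
    std::memcpy(back, storage, sizeof(back));
    printf("round-trip ok: %d\n", back[7] == 8 ? 1 : 0);
    return 0;
}
```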
@@ -3210,6 +3210,7 @@ struct ggml_backend_cuda_device_context {
     int device;
     std::string name;
     std::string description;
+    std::string pci_bus_id;
 };
 
 static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -3234,9 +3235,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
 }
 
 static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+
     props->name        = ggml_backend_cuda_device_get_name(dev);
     props->description = ggml_backend_cuda_device_get_description(dev);
     props->type        = ggml_backend_cuda_device_get_type(dev);
+    props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
     ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
 
     bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
@@ -3804,6 +3808,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
             CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
             dev_ctx->description = prop.name;
 
+            char pci_bus_id[16] = {};
+            snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
+            dev_ctx->pci_bus_id = pci_bus_id;
+
             ggml_backend_dev_t dev = new ggml_backend_device {
                 /* .iface   = */ ggml_backend_cuda_device_interface,
                 /* .reg     = */ &reg,
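The `"%04x:%02x:%02x.0"` format above produces the canonical PCI bus id string. A small C++ demo with made-up domain/bus/device values:

```cpp
#include <cstdio>

// Domain 0x0000, bus 0x3b, device 0x00 formats as "0000:3b:00.0"
// (the values here are invented for illustration).
int main() {
    char pci_bus_id[16] = {};
    const int domain = 0x0000, bus = 0x3b, device = 0x00;
    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", domain, bus, device);
    printf("%s\n", pci_bus_id); // prints: 0000:3b:00.0
    return 0;
}
```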
@@ -162,6 +162,14 @@
 #define GCN
 #endif
 
+#if defined(__gfx900__) || defined(__gfx906__)
+#define GCN5
+#endif
+
+#if defined(__gfx803__)
+#define GCN4
+#endif
+
 #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
 #define CDNA // For the entire family
 #endif
@@ -6,6 +6,7 @@ message(STATUS "Metal framework found")
 
 ggml_add_backend_library(ggml-metal
                          ggml-metal.m
+                         ggml-metal-common.cpp
                          )
 
 target_link_libraries(ggml-metal PRIVATE
@@ -0,0 +1,458 @@
+#include "ggml-metal-common.h"
+
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include <vector>
+
+// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
+struct ggml_mem_range {
+    uint64_t pb; // buffer id
+
+    uint64_t p0; // begin
+    uint64_t p1; // end
+
+    ggml_mem_range_type pt;
+};
+
+struct ggml_mem_ranges {
+    std::vector<ggml_mem_range> ranges;
+
+    int debug = 0;
+};
+
+struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
+    auto * res = new ggml_mem_ranges;
+
+    res->ranges.reserve(256);
+    res->debug = debug;
+
+    return res;
+}
+
+void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
+    delete mrs;
+}
+
+void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
+    mrs->ranges.clear();
+}
+
+static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
+    mrs->ranges.push_back(mr);
+
+    return true;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
+    // always use the base tensor
+    tensor = tensor->view_src ? tensor->view_src : tensor;
+
+    GGML_ASSERT(!tensor->view_src);
+
+    ggml_mem_range mr;
+
+    if (tensor->buffer) {
+        // when the tensor is allocated, use the actual memory address range in the buffer
+        //
+        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
+        // this can be larger than the tensor size if the buffer type allocates extra memory
+        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+        mr = {
+            /*.pb =*/ (uint64_t) tensor->buffer,
+            /*.p0 =*/ (uint64_t) tensor->data,
+            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
+            /*.pt =*/ pt,
+        };
+    } else {
+        // otherwise, the pointer address is used as an unique id of the memory ranges
+        // that the tensor will be using when it is allocated
+        mr = {
+            /*.pb =*/ (uint64_t) tensor,
+            /*.p0 =*/ 0,    //
+            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
+            /*.pt =*/ pt,
+        };
+    };
+
+    return mr;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
+    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
+    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+}
+
+static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
+
+    if (mrs->debug > 2) {
+        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return ggml_mem_ranges_add(mrs, mr);
+}
+
+static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
+
+    if (mrs->debug > 2) {
+        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
+    }
+
+    return ggml_mem_ranges_add(mrs, mr);
+}
+
+bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (tensor->src[i]) {
+            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+        }
+    }
+
+    return ggml_mem_ranges_add_dst(mrs, tensor);
+}
+
+static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
+    for (size_t i = 0; i < mrs->ranges.size(); i++) {
+        const auto & cmp = mrs->ranges[i];
+
+        // two memory ranges cannot intersect if they are in different buffers
+        if (mr.pb != cmp.pb) {
+            continue;
+        }
+
+        // intersecting source ranges are allowed
+        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+            continue;
+        }
+
+        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
+            if (mrs->debug > 2) {
+                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
+                        __func__,
+                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mr.pb, mr.p0, mr.p1,
+                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        cmp.pb, cmp.p0, cmp.p1);
+            }
+
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
+
+    const bool res = ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
+
+    const bool res = ggml_mem_ranges_check(mrs, mr);
+
+    return res;
+}
+
+bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (tensor->src[i]) {
+            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+                return false;
+            }
+        }
+    }
+
+    return ggml_mem_ranges_check_dst(mrs, tensor);
+}
+
+// TODO: move to ggml.h?
+static bool is_empty(ggml_op op) {
+    switch (op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+struct node_info {
+    ggml_tensor * node;
+
+    std::vector<ggml_tensor *> fused;
+
+    ggml_op op() const {
+        return node->op;
+    }
+
+    const ggml_tensor * dst() const {
+        return fused.empty() ? node : fused.back();
+    }
+
+    bool is_empty() const {
+        return ::is_empty(node->op);
+    }
+
+    void add_fused(ggml_tensor * t) {
+        fused.push_back(t);
+    }
+};
+
+static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+    // helper to add node src and dst ranges
+    const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        // keep track of the sources of the fused nodes as well
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return ggml_mem_ranges_add_dst(mrs, node.dst());
+    };
+
+    // helper to check if a node can run concurrently with the existing set of nodes
+    const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return ggml_mem_ranges_check_dst(mrs, node.dst());
+    };
+
+    // perform reorders only across these types of ops
+    // can be expanded when needed
+    // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
+    //            the dependencies from such ops are not always represented in the graph
+    const auto & h_safe = [](ggml_op op) {
+        switch (op) {
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_MUL_MAT_ID:
+            case GGML_OP_ROPE:
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+            case GGML_OP_GROUP_NORM:
+            case GGML_OP_SUM_ROWS:
+            case GGML_OP_MUL:
+            case GGML_OP_ADD:
+            case GGML_OP_DIV:
+            case GGML_OP_GLU:
+            case GGML_OP_SCALE:
+            case GGML_OP_GET_ROWS:
+                return true;
+            default:
+                return is_empty(op);
+        }
+    };
+
+    const int n = nodes.size();
+
+    std::vector<int> res;
+    res.reserve(n);
+
+    std::vector<bool> used(n, false);
+
+    // the memory ranges for the set of currently concurrent nodes
+    ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+
+    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
+    ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
+
+    for (int i0 = 0; i0 < n; i0++) {
+        if (used[i0]) {
+            continue;
+        }
+
+        const auto & node0 = nodes[i0];
+
+        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
+        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
+        //
+        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
+        if (!node0.is_empty() && !h_check(mrs0, node0)) {
+            // this will hold the set of memory ranges from the nodes that haven't been processed yet
+            // if a node is not concurrent with this set, we cannot reorder it
+            ggml_mem_ranges_reset(mrs1);
+
+            // initialize it with the current node
+            h_add(mrs1, node0);
+
+            // that many nodes forward to search for a concurrent node
+            constexpr int N_FORWARD = 8;
+
+            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+                if (used[i1]) {
+                    continue;
+                }
+
+                const auto & node1 = nodes[i1];
+
+                // disallow reordering of certain ops
+                if (!h_safe(node1.op())) {
+                    break;
+                }
+
+                const bool is_empty = node1.is_empty();
+
+                // to reorder a node and add it to the concurrent set, it has to be:
+                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
+                    h_add(mrs0, node1);
+                    res.push_back(i1);
+
+                    // mark as used, so we skip re-processing it later
+                    used[i1] = true;
+                } else {
+                    // expand the set of nodes that haven't been processed yet
+                    h_add(mrs1, node1);
+                }
+            }
+
+            // finalize the concurrent set and begin a new one
+            ggml_mem_ranges_reset(mrs0);
+        }
+
+        // expand the concurrent set with the current node
+        {
+            h_add(mrs0, node0);
+            res.push_back(i0);
+        }
+    }
+
+    ggml_mem_ranges_free(mrs0);
+    ggml_mem_ranges_free(mrs1);
+
+    return res;
+}
+
+void ggml_metal_graph_optimize(ggml_cgraph * gf) {
+    constexpr int MAX_FUSE = 16;
+
+    const int n = gf->n_nodes;
+
+    enum ggml_op ops[MAX_FUSE];
+
+    std::vector<node_info> nodes;
+    nodes.reserve(gf->n_nodes);
+
+    // fuse nodes:
+    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+    for (int i = 0; i < n; i++) {
+        node_info node = {
+            /*.node  =*/ gf->nodes[i],
+            /*.fused =*/ {},
+        };
+
+        // fuse only ops that start with these operations
+        // can be expanded when needed
+        if (node.op() == GGML_OP_ADD ||
+            node.op() == GGML_OP_RMS_NORM) {
+            ops[0] = node.op();
+
+            int f = i + 1;
+            while (f < n && f < i + MAX_FUSE) {
+                // conservatively allow fusing only these ops
+                // can be expanded when needed
+                if (gf->nodes[f]->op != GGML_OP_ADD &&
+                    gf->nodes[f]->op != GGML_OP_MUL &&
+                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+                    break;
+                }
+                ops[f - i] = gf->nodes[f]->op;
+                f++;
+            }
+
+            f -= i;
+            for (; f > 1; f--) {
+                if (ggml_can_fuse(gf, i, ops, f)) {
+                    break;
+                }
+            }
+
+            // add the fused tensors into the node info so we can unfuse them later
+            for (int k = 1; k < f; k++) {
+                ++i;
+
+                // the .dst() becomes the last fused tensor
+                node.add_fused(gf->nodes[i]);
+            }
+        }
+
+        nodes.push_back(std::move(node));
+    }
+
+#if 1
+    // reorder to improve concurrency
+    const auto order = ggml_metal_graph_optimize_reorder(nodes);
+#else
+    std::vector<int> order(nodes.size());
+    for (size_t i = 0; i < nodes.size(); i++) {
+        order[i] = i;
+    }
+#endif
+
+    // unfuse
+    {
+        int j = 0;
+        for (const auto i : order) {
+            const auto & node = nodes[i];
+
+            gf->nodes[j++] = node.node;
+
+            for (auto * fused : node.fused) {
+                gf->nodes[j++] = fused;
+            }
+        }
+    }
+}
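The reordering above boils down to one interval rule: two ops may run concurrently unless one of them writes a byte range that the other reads or writes. A self-contained C++ toy model of that rule over plain intervals (simplified: the real code additionally keys ranges by buffer id and tracks the sources of fused nodes):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

struct range { uint64_t p0, p1; bool is_dst; }; // [p0, p1); is_dst == written

// Mirrors the check above: read/read overlaps are allowed, anything involving
// a write is a conflict. The overlap test (p1 >= cmp.p0) is deliberately
// conservative, treating ranges that merely touch as conflicting.
static bool compatible(const std::vector<range> & set, const range & r) {
    for (const auto & cmp : set) {
        if (!r.is_dst && !cmp.is_dst) continue;
        if (r.p0 < cmp.p1 && r.p1 >= cmp.p0) return false;
    }
    return true;
}

int main() {
    std::vector<range> concurrent;
    concurrent.push_back({0, 100, true});  // op A writes [0, 100)
    range b = {50, 60, false};             // op B reads  [50, 60)
    range c = {200, 300, false};           // op C reads  [200, 300)
    printf("B concurrent with A: %d\n", compatible(concurrent, b)); // 0: reads A's output
    printf("C concurrent with A: %d\n", compatible(concurrent, c)); // 1: disjoint ranges
    return 0;
}
```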
@@ -0,0 +1,52 @@
+// helper functions for ggml-metal that are too difficult to implement in Objective-C
+
+#pragma once
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+enum ggml_mem_range_type {
+    MEM_RANGE_TYPE_SRC = 0,
+    MEM_RANGE_TYPE_DST = 1,
+};
+
+// a helper object that can be used for reordering operations to improve concurrency
+//
+// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
+//   don't write to a memory that is being read by another task or written to by another task in the set
+//
+// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
+//   can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
+//   tasks already in the set)
+//
+struct ggml_mem_ranges;
+
+struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
+void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
+
+// remove all ranges from the set
+void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
+
+// add src or dst ranges to track
+bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+
+// return false if:
+//   - new src range overlaps with any existing dst range
+//   - new dst range overlaps with any existing range (src or dst)
+bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+
+// reorder the nodes in the graph to improve concurrency, while respecting fusion
+//
+// note: this implementation is generic and not specific to metal
+//       if it proves to work well, we can start using it for other backends in the future
+void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
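A hedged sketch of how a backend might drive this API when batching ops into one concurrently executed command encoder. The call pattern is inferred from the declarations above, not taken from the Metal backend itself, and `encode_with_barriers` is a hypothetical name:

```cpp
#include "ggml-metal-common.h"

// Walk the graph nodes in order; whenever a node conflicts with the current
// concurrent set, emit a (backend-specific) barrier and start a fresh set.
void encode_with_barriers(int n_nodes, struct ggml_tensor ** nodes) {
    struct ggml_mem_ranges * mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    for (int i = 0; i < n_nodes; i++) {
        if (!ggml_mem_ranges_check(mrs, nodes[i])) {
            // conflict: all previously encoded ops must complete first
            // ... emit barrier here ...
            ggml_mem_ranges_reset(mrs);
        }
        ggml_mem_ranges_add(mrs, nodes[i]);
        // ... encode nodes[i] here ...
    }

    ggml_mem_ranges_free(mrs);
}
```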
File diff suppressed because it is too large
@@ -928,7 +928,7 @@ kernel void kernel_add_fuse_impl(
 
 typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
 
-template [[host_name("kernel_add")]]        kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
+template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
 template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
 template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
 template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
@@ -937,7 +937,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_
 template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
 template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
 
-kernel void kernel_sub(
+kernel void kernel_sub_fuse_1(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -963,7 +963,7 @@ kernel void kernel_sub(
     }
 }
 
-kernel void kernel_mul(
+kernel void kernel_mul_fuse_1(
         constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
@@ -996,7 +996,7 @@ kernel void kernel_mul(
     }
 }
 
-kernel void kernel_div(
+kernel void kernel_div_fuse_1(
        constant ggml_metal_kargs_bin & args,
        device const char * src0,
        device const char * src1,
@@ -1096,23 +1096,17 @@ kernel void kernel_add_row_c4_fuse_impl(
        device const char * src1,
        device       char * dst,
        uint tpig[[thread_position_in_grid]]) {
 
    const uint nb = args.ne00/4;
    const uint i  = tpig % nb;
 
    device const float4 * src0_row = (device const float4 *) (src0);
    device       float4 * dst_row  = (device       float4 *) (dst);
 
-    device const float4 * src1_row[F];
-    for (short j = 0; j < F; ++j) {
-        src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
-    }
-
    float4 res = src0_row[tpig];
 
 #pragma unroll(F)
    for (short j = 0; j < F; ++j) {
-        res += src1_row[j][i];
+        res += ((device const float4 *) (src1 + args.o1[j]))[i];
    }
 
    dst_row[tpig] = res;
@@ -1120,7 +1114,7 @@ kernel void kernel_add_row_c4_fuse_impl(
 
 typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
 
-template [[host_name("kernel_add_row_c4")]]        kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
+template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
 template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
 template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
 template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
@@ -1160,7 +1154,7 @@ kernel void kernel_sub_row_c4_fuse_impl(
 
 typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
 
-template [[host_name("kernel_sub_row_c4")]]        kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
+template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
 
 template <short F>
 kernel void kernel_mul_row_c4_fuse_impl(
@@ -1193,7 +1187,7 @@ kernel void kernel_mul_row_c4_fuse_impl(
 
 typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
 
-template [[host_name("kernel_mul_row_c4")]]        kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
+template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
 
 template <short F>
 kernel void kernel_div_row_c4_fuse_impl(
@@ -1226,7 +1220,7 @@ kernel void kernel_div_row_c4_fuse_impl(
 
 typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
 
-template [[host_name("kernel_div_row_c4")]]        kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
+template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
 
 kernel void kernel_scale(
        device const float * src0,
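With the `_fuse_1` suffix the unfused variants follow the same naming scheme as the fused ones, so a host-side lookup can be built mechanically from the fuse count. A hypothetical helper, not taken from the backend:

```cpp
#include <string>

// Hypothetical: derive the Metal kernel name for an n-way fused add,
// matching the kernel_add_fuse_1 ... kernel_add_fuse_8 instantiations above.
static std::string add_kernel_name(int n_fuse) {
    return "kernel_add_fuse_" + std::to_string(n_fuse); // valid for n_fuse in [1, 8]
}
```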
@@ -225,9 +225,9 @@ struct bin_bcast_sycl {
                 dpct::has_capability_or_fail(stream->get_device(),
                                              {sycl::aspect::fp16});
 
-                sycl_parallel_for(
-                    stream,
-                    sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * sycl::range<3>(1, 1, block_size),
-                                      sycl::range<3>(1, 1, block_size)),
+                stream->parallel_for(
+                    sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                          sycl::range<3>(1, 1, block_size),
+                                      sycl::range<3>(1, 1, block_size)),
                     [=](sycl::nd_item<3> item_ct1) {
                         k_bin_bcast_unravel<bin_op>(
@@ -246,8 +246,9 @@ struct bin_bcast_sycl {
                 dpct::has_capability_or_fail(stream->get_device(),
                                              {sycl::aspect::fp16});
 
-                sycl_parallel_for(
-                    stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
                                             ne2, ne3, ne10, ne11, ne12, ne13,
                                             s1, s2, s3, s01, s02, s03, s11, s12, s13,
@@ -89,23 +89,32 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
     switch (dim) {
     case 0:
-        sycl_parallel_for(stream,
-                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1); });
+        stream->parallel_for(
+            sycl::nd_range<3>(gridDim *
+                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
+            });
         break;
     case 1:
-        sycl_parallel_for(stream,
-                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1); });
+        stream->parallel_for(
+            sycl::nd_range<3>(gridDim *
+                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
+            });
         break;
     // dim >=2 will be dispatched to the default path
     default:
-        sycl_parallel_for(stream,
-                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-                          [=](sycl::nd_item<3> item_ct1) { concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1); });
+        stream->parallel_for(
+            sycl::nd_range<3>(gridDim *
+                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
+            });
         break;
     }
 }
@@ -120,7 +129,7 @@ static void concat_f32_sycl_non_cont(
     int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
     uint64_t nb3, int32_t dim) {
     sycl::range<3> gridDim(ne3, ne2, ne1);
-    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
         int64_t i3 = item_ct1.get_group(0);
         int64_t i2 = item_ct1.get_group(1);
         int64_t i1 = item_ct1.get_group(2);
@@ -59,9 +59,15 @@ static void conv_transpose_1d_f32_f32_sycl(
     const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
     const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, 1, num_blocks);
-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-        conv_transpose_1d_kernel(s0, output_size, src0_ne0, src0_ne1, src0_ne2, src1_ne0, dst_ne0, src0, src1, dst,
-                                 item_ct1);
+    stream->parallel_for(
+        sycl::nd_range<3>(
+            block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) {
+            conv_transpose_1d_kernel(
+                s0, output_size,
+                src0_ne0, src0_ne1, src0_ne2,
+                src1_ne0, dst_ne0,
+                src0, src1, dst, item_ct1);
         });
 }
 
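The SYCL hunks above and below revert helper-based launches back to direct `sycl::queue` calls; the two spellings are functionally equivalent. A sketch of both forms side by side, where the wrapper shown is an assumption about what the removed `sycl_parallel_for` helper did, inferred from how the call sites change:

```cpp
#include <sycl/sycl.hpp>
#include <utility>

// Assumed shape of the removed helper: forward everything to queue::parallel_for.
template <typename... Args>
static auto sycl_parallel_for(sycl::queue * q, Args &&... args) {
    return q->parallel_for(std::forward<Args>(args)...);
}

int main() {
    sycl::queue q;
    sycl::nd_range<3> r(sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64));

    // helper form (old side of the diff)
    sycl_parallel_for(&q, r, [=](sycl::nd_item<3>) {});
    // direct form (new side of the diff)
    q.parallel_for(r, [=](sycl::nd_item<3>) {});

    q.wait();
    return 0;
}
```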
@@ -33,11 +33,14 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
-        sycl_parallel_for(
-            stream,
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1); });
+        stream->parallel_for(
+            sycl::nd_range<3>(
+                sycl::range<3>(1, 1, num_blocks) *
+                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
+            });
     }
 }
 
@@ -50,18 +53,24 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q2_K(vx, y, item_ct1);
+                             });
     }
 #else
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q2_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q2_K(vx, y, item_ct1);
+                             });
     }
 
 #endif
@@ -76,18 +85,24 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_K(vx, y, item_ct1);
+                             });
     }
 #else
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q3_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q3_K(vx, y, item_ct1);
+                             });
     }
 #endif
 }
@@ -101,9 +116,12 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_0(vx, y, nb32, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q4_0(vx, y, nb32, item_ct1);
+                             });
     }
 }
 
@@ -117,12 +135,13 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
     int constexpr WARP_K = WARP_SIZE * QK4_0;
     const int n_warp = (k + WARP_K - 1) / WARP_K;
     GGML_ASSERT(k % 2 == 0);
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * sycl::range<3>(1, 1, WARP_SIZE),
-                                        sycl::range<3>(1, 1, WARP_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
+                                               sycl::range<3>(1, 1, WARP_SIZE),
+                                           sycl::range<3>(1, 1, WARP_SIZE)),
+                         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
                              dequantize_block_q4_0_reorder(vx, y, k, item_ct1);
                          });
 
 }
 
 template <typename dst_t>
@@ -134,9 +153,12 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q4_1(vx, y, nb32, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q4_1(vx, y, nb32, item_ct1);
+                             });
     }
 }
 
@@ -149,10 +171,11 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
                 [=](sycl::nd_item<3> item_ct1) {
                     dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
                 });
@@ -168,10 +191,10 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
 
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-    sycl_launch(stream, [&](sycl::handler & cgh) {
+    stream->submit([&](sycl::handler & cgh) {
         sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
 
-        sycl_parallel_for<1>(cgh, sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
+        cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
                          [=](sycl::nd_item<1> item_ct1) {
                              dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
                          });
@@ -187,18 +210,24 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q5_K(vx, y, item_ct1);
+                             });
     }
 #else
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q5_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q5_K(vx, y, item_ct1);
+                             });
     }
 
 #endif
@@ -213,18 +242,24 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 64),
+                                               sycl::range<3>(1, 1, 64)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q6_K(vx, y, item_ct1);
+                             });
     }
 #else
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-            [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K(vx, y, item_ct1); });
+        stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_q6_K(vx, y, item_ct1);
+                             });
    }
 
 #endif
@@ -236,7 +271,7 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
 
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-    sycl_parallel_for(stream,
+    stream->parallel_for(
         sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
         [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
 }
@@ -249,10 +284,15 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-                [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_s(vx, y, item_ct1, iq1s_grid_gpu); });
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_s(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                 );
+                             });
         });
     }
 }
@@ -265,10 +305,15 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-                [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq1_m(vx, y, item_ct1, iq1s_grid_gpu); });
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq1_m(
+                                     vx, y, item_ct1, iq1s_grid_gpu
+                                 );
+                             });
         });
     }
 }
@@ -281,11 +326,14 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
                 [=](sycl::nd_item<3> item_ct1) {
-                    dequantize_block_iq2_xxs(vx, y, item_ct1, iq2xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+                    dequantize_block_iq2_xxs(
+                        vx, y, item_ct1, iq2xxs_grid,
+                        ksigns_iq2xs, kmask_iq2xs);
                 });
         });
     }
@@ -299,11 +347,14 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
                 [=](sycl::nd_item<3> item_ct1) {
-                    dequantize_block_iq2_xs(vx, y, item_ct1, iq2xs_grid, ksigns_iq2xs, kmask_iq2xs);
+                    dequantize_block_iq2_xs(
+                        vx, y, item_ct1, iq2xs_grid,
+                        ksigns_iq2xs, kmask_iq2xs);
                 });
         });
     }
@@ -317,10 +368,13 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-                [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq2_s(vx, y, item_ct1); });
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq2_s(vx, y, item_ct1);
+                             });
         });
     }
 }
@@ -334,11 +388,14 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
                 [=](sycl::nd_item<3> item_ct1) {
-                    dequantize_block_iq3_xxs(vx, y, item_ct1, iq3xxs_grid, ksigns_iq2xs, kmask_iq2xs);
+                    dequantize_block_iq3_xxs(
+                        vx, y, item_ct1, iq3xxs_grid,
+                        ksigns_iq2xs, kmask_iq2xs);
                 });
         });
     }
@@ -352,10 +409,14 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-                [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq3_s(vx, y, item_ct1, kmask_iq2xs, iq3s_grid); });
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                                   sycl::range<3>(1, 1, 32),
+                                               sycl::range<3>(1, 1, 32)),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 dequantize_block_iq3_s(
+                                     vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
+                             });
         });
     }
 }
@@ -371,11 +432,14 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh,
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
-                [=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_xs(vx, y, item_ct1); });
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
+                                      sycl::range<3>(1, 1, 32),
+                                  sycl::range<3>(1, 1, 32)),
+                [=](sycl::nd_item<3> item_ct1) {
+                    dequantize_block_iq4_xs(vx, y, item_ct1);
+                });
         });
     }
 #endif
@@ -389,11 +453,14 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(
-                cgh,
-                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
|
sycl::range<3>(1, 1, 32),
|
||||||
[=](sycl::nd_item<3> item_ct1) { dequantize_block_iq4_nl(vx, y, item_ct1); });
|
sycl::range<3>(1, 1, 32)),
|
||||||
|
[=](sycl::nd_item<3> item_ct1) {
|
||||||
|
dequantize_block_iq4_nl(vx, y, item_ct1);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
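Note: every dequantize hunk above makes the same mechanical change: the removed side's sycl_launch + sycl_parallel_for wrapper pair becomes a plain stream->submit plus cgh.parallel_for on the added side. As a rough self-contained sketch of the surviving launch shape (assumes a SYCL 2020 compiler; launch_blocks, the buffer y, and the kernel body are illustrative placeholders, not the real dequantizers):

    #include <sycl/sycl.hpp>

    // y must hold at least nb * 32 floats; one work-group of 32 work-items per block.
    static void launch_blocks(sycl::queue * stream, float * y, int nb) {
        stream->submit([&](sycl::handler & cgh) {
            cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32),
                                               sycl::range<3>(1, 1, 32)),
                             [=](sycl::nd_item<3> item_ct1) {
                                 y[item_ct1.get_global_id(2)] = 0.0f; // placeholder kernel body
                             });
        });
    }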
@@ -201,8 +201,7 @@ static void ggml_cpy_f16_f32_sycl(const char * cx, char * cdst, const int ne, co
     {
         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -220,8 +219,7 @@ static void ggml_cpy_f32_f32_sycl(const char * cx, char * cdst, const int ne, co
     {
         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -239,8 +237,7 @@ static void ggml_cpy_f32_f16_sycl(const char * cx, char * cdst, const int ne, co
     {
         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -256,7 +253,7 @@ static void ggml_cpy_f32_q8_0_sycl(const char * cx, char * cdst, const int ne, c
                                    const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -268,7 +265,7 @@ static void ggml_cpy_q8_0_f32_sycl(const char * cx, char * cdst, const int ne, c
                                    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ne;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -281,7 +278,7 @@ static void ggml_cpy_f32_q4_0_sycl(const char * cx, char * cdst, const int ne, c
                                    const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -293,9 +290,8 @@ static void ggml_cpy_q4_0_f32_sycl(const char * cx, char * cdst, const int ne, c
                                    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ne;
-    sycl_parallel_for(
-        stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
-        [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
             cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                      nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                      item_ct1);
@@ -308,7 +304,7 @@ static void ggml_cpy_f32_q4_1_sycl(const char * cx, char * cdst, const int ne, c
                                    const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -320,9 +316,8 @@ static void ggml_cpy_q4_1_f32_sycl(const char * cx, char * cdst, const int ne, c
                                    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ne;
-    sycl_parallel_for(
-        stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
-        [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
             cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                      nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                      item_ct1);
@@ -335,7 +330,7 @@ static void ggml_cpy_f32_q5_0_sycl(const char * cx, char * cdst, const int ne, c
                                    const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK5_0 == 0);
     const int num_blocks = ne / QK5_0;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -347,9 +342,8 @@ static void ggml_cpy_q5_0_f32_sycl(const char * cx, char * cdst, const int ne, c
                                    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ne;
-    sycl_parallel_for(
-        stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
-        [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
             cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                      nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                      item_ct1);
@@ -362,7 +356,7 @@ static void ggml_cpy_f32_q5_1_sycl(const char * cx, char * cdst, const int ne, c
                                    const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK5_1 == 0);
     const int num_blocks = ne / QK5_1;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
                          [=](sycl::nd_item<3> item_ct1) {
                              cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
                                                                  ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
@@ -374,9 +368,8 @@ static void ggml_cpy_q5_1_f32_sycl(const char * cx, char * cdst, const int ne, c
                                    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                    const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ne;
-    sycl_parallel_for(
-        stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
-        [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
             cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
                                                                      nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
                                                                      item_ct1);
@@ -389,10 +382,10 @@ static void ggml_cpy_f32_iq4_nl_sycl(const char * cx, char * cdst, const int ne,
                                      const int nb12, const int nb13, queue_ptr stream) {
     GGML_ASSERT(ne % QK4_NL == 0);
     const int num_blocks = ne / QK4_NL;
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
-                                                                 ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks), sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
+            cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
+                                                   ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
 
@@ -404,8 +397,7 @@ static void ggml_cpy_f16_f16_sycl(const char * cx, char * cdst, const int ne, co
     {
         dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -424,8 +416,7 @@ static void ggml_cpy_i16_i16_sycl(const char * cx, char * cdst, const int ne, co
         // dpct::has_capability_or_fail(stream->get_device(),
         //                              {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -444,8 +435,7 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co
         // dpct::has_capability_or_fail(stream->get_device(),
         //                              {sycl::aspect::fp16});
 
-        sycl_parallel_for(
-            stream,
+        stream->parallel_for(
             sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
                               sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
             [=](sycl::nd_item<3> item_ct1) {
@@ -460,12 +450,10 @@ static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const
                                const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
-    sycl_parallel_for(stream,
+    stream->parallel_for(
                       sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
-                                                     ne12, nb10, nb11, nb12, nb13, item_ct1);
+                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
+        cpy_q_q<block_q8_0, QK8_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
 
@@ -475,12 +463,10 @@ static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const
                                const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
-    sycl_parallel_for(stream,
+    stream->parallel_for(
                       sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
-                                                     ne12, nb10, nb11, nb12, nb13, item_ct1);
+                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
+        cpy_q_q<block_q5_0, QK5_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
 
@@ -491,12 +477,10 @@ static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const
                                const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
 
-    sycl_parallel_for(stream,
+    stream->parallel_for(
                       sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
-                                                     ne12, nb10, nb11, nb12, nb13, item_ct1);
+                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
+        cpy_q_q<block_q5_1, QK5_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
 
@@ -506,12 +490,9 @@ static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const
                                const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13, queue_ptr stream) {
     const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
-                                                     ne12, nb10, nb11, nb12, nb13, item_ct1);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
+            cpy_q_q<block_q4_0, QK4_0>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
 
@@ -522,12 +503,9 @@ static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const
                                const int nb12, const int nb13, queue_ptr stream) {
 
     const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE);
-    sycl_parallel_for(stream,
-                      sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                                        sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-                      [=](sycl::nd_item<3> item_ct1) {
-                          cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11,
-                                                     ne12, nb10, nb11, nb12, nb13, item_ct1);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) {
+            cpy_q_q<block_q4_1, QK4_1>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1);
         });
 }
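Note: in the copy kernels the leading stream argument disappears because sycl::queue::parallel_for is the standard handler-free shortcut for submitting an nd-range kernel. A minimal, self-contained usage sketch under that assumption (the sizes and the kernel body are illustrative):

    #include <sycl/sycl.hpp>

    int main() {
        sycl::queue q;
        const int num_blocks = 4, block_size = 32;
        float * dst = sycl::malloc_shared<float>(num_blocks * block_size, q);
        // same shape as the cpy launches: global range = num_blocks * block_size in dim 2
        q.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, block_size),
                                         sycl::range<3>(1, 1, block_size)),
                       [=](sycl::nd_item<3> it) { dst[it.get_global_id(2)] = 1.0f; });
        q.wait();
        sycl::free(dst, q);
        return 0;
    }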
@@ -208,9 +208,11 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<1, 1, convert_f16>(vx, y, dst, ncols,
+                                                          nrows, item_ct1);
             });
     }
 }
@@ -875,10 +877,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloa
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(vx, y, dst, ncols,
-                                                                                      nrows, item_ct1);
+                dequantize_mul_mat_vec_reorder<QK4_0, QR4_0, dequantize_q4_0_reorder>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -897,9 +900,11 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -916,9 +921,11 @@ static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -935,9 +942,11 @@ static void dequantize_mul_mat_vec_q5_0_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -954,9 +963,11 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -973,9 +984,11 @@ static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y,
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(vx, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>(
+                    vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 }
@@ -989,7 +1002,8 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
         });
@@ -1004,7 +1018,8 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
         });
@@ -1019,7 +1034,8 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
         });
@@ -1031,7 +1047,8 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
     const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
-    sycl_parallel_for(stream, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
         });
@@ -1046,7 +1063,8 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] {
             dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
         });
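Note: in the mul-mat-vec hunks the [[sycl::reqd_sub_group_size(...)]] attribute on the kernel lambda survives unchanged on both sides; only the launch entry point moves from the helper to the queue. A small sketch of that attribute in isolation (standard SYCL 2020; the value 32 stands in for WARP_SIZE, and the target device must support that sub-group size):

    #include <sycl/sycl.hpp>

    // n must be a multiple of the work-group size (32 here).
    void launch_pinned(sycl::queue & q, float * dst, size_t n) {
        q.parallel_for(sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(32)),
                       [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(32)]] {
                           // sub-groups are guaranteed to have exactly 32 work-items
                           dst[it.get_global_id(0)] = float(it.get_sub_group().get_local_linear_id());
                       });
    }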
@@ -13,10 +13,10 @@
 #ifndef GGML_SYCL_DPCT_HELPER_HPP
 #define GGML_SYCL_DPCT_HELPER_HPP
 
-#include <map>
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
 #include <syclcompat/math.hpp>
+#include <map>
 
 #ifdef GGML_SYCL_USE_INTEL_ONEMKL
 #include <oneapi/mkl.hpp>
@@ -118,36 +118,6 @@ inline auto get_onemath_backend(sycl::queue& queue)
 #endif
 }
 
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
-namespace syclex = sycl::ext::oneapi::experimental;
-#endif
-
-template <int NR, typename Func>
-__dpct_inline__ void sycl_parallel_for(sycl::handler & cgh, sycl::nd_range<NR> nd_range, Func && func) {
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
-    syclex::nd_launch(cgh, nd_range, func);
-#else
-    cgh.parallel_for(nd_range, func);
-#endif
-}
-
-template <int NR, typename Func>
-__dpct_inline__ void sycl_parallel_for(sycl::queue * q, sycl::nd_range<NR> nd_range, Func && func) {
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
-    syclex::nd_launch(*q, nd_range, func);
-#else
-    q->parallel_for(nd_range, func);
-#endif
-}
-
-template <typename Func> __dpct_inline__ void sycl_launch(sycl::queue * stream, Func && func) {
-#ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
-    syclex::submit(*stream, func);
-#else
-    stream->submit(func);
-#endif
-}
-
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
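Note: with the helpers above deleted, the SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS path is gone entirely; every former sycl_parallel_for/sycl_launch call site now performs the plain-SYCL branch directly. For reference, this is the dispatch shape that was removed (launch_nd is an illustrative name, not a new API in the tree):

    #include <sycl/sycl.hpp>

    #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
    namespace syclex = sycl::ext::oneapi::experimental;
    #endif

    template <int NR, typename Func>
    inline void launch_nd(sycl::queue * q, sycl::nd_range<NR> ndr, Func && f) {
    #ifdef SYCL_EXT_ONEAPI_ENQUEUE_FUNCTIONS
        syclex::nd_launch(*q, ndr, f);   // extension path the removed helpers preferred
    #else
        q->parallel_for(ndr, f);         // fallback; now the only path, inlined at each call site
    #endif
    }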
@@ -407,7 +407,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
                          const int ne12, const int nb1, const int nb2,
                          const int offset, queue_ptr stream) {
     int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
-    sycl_parallel_for(stream,
+    stream->parallel_for(
         sycl::nd_range<1>(sycl::range<1>(num_blocks) *
                               sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
                           sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
@@ -425,8 +425,8 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
     int dst_size = ne10 * ne11 * ne12 * ne13;
     int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
     sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
-    sycl_parallel_for<1>(
-        stream, sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
+    stream->parallel_for(
+        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
             upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
         });
 }
@@ -437,7 +437,7 @@ static void pad_sycl(const T *x, T *dst, const int ne00,
                      const int ne1, const int ne2, queue_ptr stream) {
     int num_blocks = ceil_div(ne0, SYCL_PAD_BLOCK_SIZE);
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    sycl_parallel_for(stream,
+    stream->parallel_for(
         sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
                           sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
         [=](sycl::nd_item<3> item_ct1) { pad(x, dst, ne0, ne00, ne01, ne02, item_ct1); });
@@ -639,7 +639,7 @@ static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, 256);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                   sycl::range<1>(256)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -652,7 +652,7 @@ static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, 256);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                   sycl::range<1>(256)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -665,7 +665,7 @@ static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, 256);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
                                   sycl::range<1>(256)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -678,7 +678,7 @@ static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_SILU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SILU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_SILU_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -691,7 +691,7 @@ static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -704,7 +704,7 @@ static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -717,7 +717,7 @@ static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_t
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_GELU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_GELU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_GELU_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -730,7 +730,7 @@ static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_TANH_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_TANH_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_TANH_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -743,7 +743,7 @@ static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -756,7 +756,7 @@ static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggm
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_HARDSIGMOID_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_HARDSIGMOID_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -769,7 +769,7 @@ static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_HARDSWISH_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_HARDSWISH_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -782,7 +782,7 @@ static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -795,7 +795,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -808,7 +808,7 @@ static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -821,7 +821,7 @@ static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_NEG_BLOCK_SIZE); // Using NEG block size
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_NEG_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_NEG_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -834,7 +834,7 @@ static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_te
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_SIGMOID_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_SIGMOID_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -847,7 +847,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -860,7 +860,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -873,7 +873,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
             const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
                 [=](sycl::nd_item<1> item_ct1) {
@@ -888,7 +888,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml
     ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
         [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
             const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
-            sycl_parallel_for(stream,
+            stream->parallel_for(
                 sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
                                   sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
|
||||||
[=](sycl::nd_item<1> item_ct1) {
|
[=](sycl::nd_item<1> item_ct1) {
|
||||||
|
|
@ -901,7 +901,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||||
const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
|
const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
|
||||||
sycl_parallel_for(stream,
|
stream->parallel_for(
|
||||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
|
||||||
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
|
||||||
[=](sycl::nd_item<1> item_ct1) {
|
[=](sycl::nd_item<1> item_ct1) {
|
||||||
|
|
@ -935,7 +935,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
|
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
|
||||||
const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
|
const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
|
||||||
sycl_parallel_for(stream,
|
stream->parallel_for(
|
||||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
|
||||||
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
|
||||||
[=](sycl::nd_item<1> item_ct1) {
|
[=](sycl::nd_item<1> item_ct1) {
|
||||||
|
|
@ -967,7 +967,7 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
||||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||||
sycl_parallel_for(main_stream,
|
main_stream->parallel_for(
|
||||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||||
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
@ -978,7 +978,7 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
||||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
|
||||||
sycl_parallel_for(main_stream,
|
main_stream->parallel_for(
|
||||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||||
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
@ -989,7 +989,7 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
||||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||||
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
|
||||||
sycl_parallel_for(main_stream,
|
main_stream->parallel_for(
|
||||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||||
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
@ -1000,7 +1000,7 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
||||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||||
sycl_parallel_for(main_stream,
|
main_stream->parallel_for(
|
||||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||||
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
@ -1011,7 +1011,7 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm
|
||||||
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
|
||||||
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
[](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
|
||||||
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
|
||||||
sycl_parallel_for(main_stream,
|
main_stream->parallel_for(
|
||||||
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
|
||||||
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
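Every element_wise hunk above is the same mechanical change: the repository's sycl_parallel_for(stream, ...) wrapper is replaced by a direct stream->parallel_for(...) call on the underlying sycl::queue. For reference, a minimal, self-contained sketch of that 1-D elementwise launch pattern (not part of the diff; launch_neg, BLOCK_SIZE and the bounds guard are illustrative stand-ins):

    #include <sycl/sycl.hpp>

    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    constexpr int BLOCK_SIZE = 256; // stands in for SYCL_NEG_BLOCK_SIZE etc.

    // One work-item negates one element; the tail work-group is bounds-guarded.
    void launch_neg(sycl::queue * stream, const float * src, float * dst, int k_elements) {
        const int num_blocks = ceil_div(k_elements, BLOCK_SIZE);
        stream->parallel_for(
            sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(BLOCK_SIZE),
                              sycl::range<1>(BLOCK_SIZE)),
            [=](sycl::nd_item<1> item_ct1) {
                const int i = (int) item_ct1.get_global_id(0);
                if (i < k_elements) {
                    dst[i] = -src[i];
                }
            });
    }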
@@ -118,9 +118,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr

     GGML_ASSERT(ne00 % 2 == 0);

-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-        k_get_rows<qk, qr, dq>(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2, s3, nb01, nb02, nb03, s10, s11, s12,
-                               item_ct1);
-    });
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_get_rows<qk, qr, dq>(
+                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });

     GGML_UNUSED(dst);

@@ -154,8 +156,9 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens
     dpct::has_capability_or_fail(stream->get_device(),
                                  {sycl::aspect::fp16});

-    sycl_parallel_for(
-        stream, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(
+        sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) {
            k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
                             s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
        });
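The get_rows hunks use the 3-D form of the same idiom. As a sketch of how sycl::nd_range<3>(block_nums * block_dims, block_dims) corresponds to a CUDA-style grid/block launch (illustrative sizes; demo_3d_launch is hypothetical, not code from this diff):

    #include <sycl/sycl.hpp>

    void demo_3d_launch(sycl::queue * stream, float * dst) {
        const sycl::range<3> block_dims(1, 1, 256); // work-group size ("block")
        const sycl::range<3> block_nums(1, 4, 32);  // work-group count ("grid")
        stream->parallel_for(
            sycl::nd_range<3>(block_nums * block_dims, block_dims), // global = grid * block
            [=](sycl::nd_item<3> item_ct1) {
                // get_group(2) plays the role of blockIdx.x, get_local_id(2) of threadIdx.x
                const size_t col = item_ct1.get_global_id(2);
                const size_t row = item_ct1.get_global_id(1);
                dst[row * (32 * 256) + col] = 0.0f; // one element per work-item
            });
    }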
@@ -1746,12 +1746,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const size_t shared_mem = ncols_pad * sizeof(int);

     if (order == GGML_SORT_ORDER_ASC) {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()

@@ -1759,12 +1760,13 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 });
         });
     } else if (order == GGML_SORT_ORDER_DESC) {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
                 sycl::range<1>(shared_mem), cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
                         x, dst, ncols, ncols_pad, item_ct1,
                         dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
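Both argsort branches also trade sycl_launch for a plain stream->submit: the kernel needs work-group local memory, and a sycl::local_accessor can only be declared inside the command-group handler before cgh.parallel_for runs. A minimal sketch of that submit-then-parallel_for shape (fill_via_local_mem is hypothetical; assumes n fits in a single work-group):

    #include <sycl/sycl.hpp>
    #include <cstdint>

    void fill_via_local_mem(sycl::queue * stream, int * dst, int n) {
        const size_t shared_mem = n * sizeof(int);
        stream->submit([&](sycl::handler & cgh) {
            // byte-addressed per-work-group scratch, like dpct_local_acc_ct1 above
            sycl::local_accessor<uint8_t, 1> scratch(sycl::range<1>(shared_mem), cgh);
            cgh.parallel_for(
                sycl::nd_range<1>(sycl::range<1>(n), sycl::range<1>(n)),
                [=](sycl::nd_item<1> it) {
                    int * buf = reinterpret_cast<int *>(
                        scratch.get_multi_ptr<sycl::access::decorated::no>().get());
                    const size_t tid = it.get_local_id(0);
                    buf[tid] = (int) tid;                               // stage in local memory
                    it.barrier(sycl::access::fence_space::local_space); // make writes visible
                    dst[it.get_global_id(0)] = buf[tid];                // copy back out
                });
        });
    }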
@@ -1782,13 +1784,15 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
     const sycl::range<3> block_nums(1, nrows, 1);
     const size_t shared_mem = 256 * sizeof(float);

-    sycl_launch(stream, [&](sycl::handler & cgh) {
+    stream->submit([&](sycl::handler &cgh) {
         sycl::local_accessor<float, 1> shared_data(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);
         sycl::local_accessor<int, 1> shared_indices(
             sycl::range<1>(shared_mem/sizeof(float)), cgh);

-        sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
             const int tid = item_ct1.get_local_id(2);
             const int row = item_ct1.get_global_id(1);

@@ -1807,7 +1811,7 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
             shared_indices[tid] = max_idx;
             item_ct1.barrier(sycl::access::fence_space::local_space);

-            for (int stride = 256 / 2; stride > 0; stride >>= 1) {
+            for (int stride = 256/2; stride > 0; stride >>= 1) {
                 if (tid < stride) {
                     float val1 = shared_data[tid];
                     float val2 = shared_data[tid + stride];

@@ -1819,6 +1823,7 @@ static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols,
                 item_ct1.barrier(sycl::access::fence_space::local_space);
             }

+
             if (tid == 0) {
                 dst[row] = shared_indices[0];
             }
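The argmax hunks touch a standard shared-memory tree reduction: stride halves each round, each surviving lane keeps the larger of its two candidates, and a local barrier separates rounds. A host-side model of the same logic (assuming a power-of-two length; on the device each tid below is a separate work-item):

    #include <vector>
    #include <utility>
    #include <cstddef>

    std::pair<float, int> argmax_tree(std::vector<float> data) {
        std::vector<int> idx(data.size());
        for (size_t i = 0; i < idx.size(); ++i) idx[i] = (int) i;
        for (size_t stride = data.size() / 2; stride > 0; stride >>= 1) {
            for (size_t tid = 0; tid < stride; ++tid) { // all lanes run concurrently on device
                if (data[tid + stride] > data[tid]) {
                    data[tid] = data[tid + stride];
                    idx[tid]  = idx[tid + stride];
                }
            }
            // the device version places a local-space barrier here
        }
        return { data[0], idx[0] }; // lane 0 holds the max and its index
    }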
@@ -2895,7 +2900,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             void ** ptrs_dst_get = ptrs_dst.get();
             size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
             size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
                 k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
                                        nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
             });

@@ -3403,7 +3408,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
     {
         sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
         sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler &cgh) {
             sycl::local_accessor<int, 0> src1_row_acc(cgh);

             char *__restrict src1_contiguous_get =

@@ -3415,8 +3420,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
             size_t ids_nb_ct6 = ids->nb[1];
             size_t ids_nb_ct7 = ids->nb[0];

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_copy_src1_to_contiguous(
                         src1_original, src1_contiguous_get,
                         dev_cur_src1_row_get,

@@ -3447,14 +3453,15 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
     {
         sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
         sycl::range<3> grid_dims(1, 1, num_src1_rows);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler &cgh) {
             const char *__restrict dst_contiguous_get =
                 dst_contiguous.get();
             const mmid_row_mapping *__restrict dev_row_mapping_get =
                 dev_row_mapping.get();

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     k_copy_dst_from_contiguous(dst_original,
                                                dst_contiguous_get,
                                                dev_row_mapping_get,
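For context on the mul_mat_id hunks: the launched kernels gather the selected src1 rows into a contiguous staging buffer before the matmul and scatter results back afterwards. A deliberately simplified host-side sketch of the gather half (gather_rows, row_ids, and the strides are hypothetical; the real k_copy_src1_to_contiguous does this per work-item on the device):

    #include <cstring>
    #include <cstddef>

    void gather_rows(const char * src, char * dst_contiguous,
                     const int * row_ids, int num_rows,
                     size_t row_bytes, size_t src_row_stride) {
        for (int r = 0; r < num_rows; ++r) {
            // pack row row_ids[r] of the strided source into slot r of the dense buffer
            std::memcpy(dst_contiguous + (size_t) r * row_bytes,
                        src + (size_t) row_ids[r] * src_row_stride,
                        row_bytes);
        }
    }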
@@ -11,13 +11,13 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B,
     const u_int n_seq_tokens = T / B;
     sycl::range<1> block_dims((C / H));
     sycl::range<1> grid_dims((B * H));
-    sycl_launch(stream, [&](sycl::handler & cgh) {
+    stream->submit([&](sycl::handler & cgh) {
         /* local memory accessors*/
         auto _k = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
         auto _r = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);
         auto _td = sycl::local_accessor<float, 1>(sycl::range<1>(head_size), cgh);

-        sycl_parallel_for<1>(cgh, sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
+        cgh.parallel_for(sycl::nd_range<1>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<1> item) {
             u_int tid = item.get_local_id(0);
             u_int bid = item.get_group(0);
@@ -70,7 +70,7 @@ static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t I

     const int64_t CHW = IC * KH * KW;

-    sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
+    stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) {
         im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1,
                          p0, p1, d0, d1, item_ct1);
     });
@@ -1818,7 +1818,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(

@@ -1829,8 +1829,9 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -1852,7 +1853,7 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(

@@ -1863,8 +1864,9 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,
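Every ggml_mul_mat_q*_q8_1_sycl hunk that follows repeats the q4_0 shape above: a stream->submit that declares local-memory tiles for the quantized operands, then cgh.parallel_for over the tiled grid. A compilable sketch of that skeleton; the accessor sizes are illustrative, and reading the extra + mmq_y as row padding is my assumption, not something the diff states:

    #include <sycl/sycl.hpp>

    constexpr int WARP_SIZE = 32;
    constexpr int QI8_1     = 32;

    void launch_mmq_like(sycl::queue * stream, sycl::nd_range<3> ndr, int mmq_x, int mmq_y) {
        stream->submit([&](sycl::handler & cgh) {
            // x-tile: quantized values plus (assumed) one int of padding per row
            sycl::local_accessor<int, 1>         tile_x_qs(sycl::range<1>(mmq_y * WARP_SIZE + mmq_y), cgh);
            // y-tile scales, shaped like tile_y_ds_acc_ct1 above
            sycl::local_accessor<sycl::half2, 1> tile_y_ds(sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);
            cgh.parallel_for(ndr, [=](sycl::nd_item<3> it) {
                // the real kernels (mul_mat_q4_0<need_check> etc.) fill and consume the
                // tiles cooperatively; here we only touch them so the sketch is complete
                if (it.get_local_linear_id() == 0) {
                    tile_x_qs[0] = 0;
                    tile_y_ds[0] = sycl::half2(0.f, 0.f);
                }
            });
        });
    }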
@@ -1931,7 +1933,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(

@@ -1942,8 +1944,9 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_1<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -1965,7 +1968,7 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + +mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(

@@ -1976,8 +1979,9 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_1<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2044,7 +2048,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(

@@ -2055,8 +2059,9 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2078,7 +2083,7 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(

@@ -2089,8 +2094,9 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2157,7 +2163,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(

@@ -2168,8 +2174,9 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_1<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2191,7 +2198,7 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(

@@ -2202,8 +2209,9 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_1<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2270,7 +2278,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(

@@ -2281,8 +2289,9 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q8_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2304,7 +2313,7 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(

@@ -2315,8 +2324,9 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q8_0<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2383,7 +2393,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(

@@ -2396,8 +2406,9 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q2_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2420,7 +2431,7 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(

@@ -2433,8 +2444,9 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q2_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2504,7 +2516,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(

@@ -2519,8 +2531,9 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q3_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2544,7 +2557,7 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(

@@ -2559,8 +2572,9 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q3_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2630,7 +2644,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(

@@ -2643,8 +2657,9 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2667,7 +2682,7 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                     sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(

@@ -2680,8 +2695,9 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q4_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2749,7 +2765,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(

@@ -2762,8 +2778,9 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2786,7 +2803,7 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(

@@ -2799,8 +2816,9 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q5_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2868,7 +2886,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(

@@ -2881,8 +2899,9 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q6_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,

@@ -2905,7 +2924,7 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
             dpct::has_capability_or_fail(stream->get_device(),
                                          {sycl::aspect::fp16});

-            sycl_launch(stream, [&](sycl::handler & cgh) {
+            stream->submit([&](sycl::handler &cgh) {
                 sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
                     sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                 sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(

@@ -2918,8 +2937,9 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
                 sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                     sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

-                sycl_parallel_for(
-                    cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                cgh.parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
                         mul_mat_q6_K<need_check>(
                             vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                             nrows_dst, item_ct1,
@@ -544,8 +544,8 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy,
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE));
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

-    sycl_launch(stream, [&](sycl::handler & cgh) {
-        sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
                          [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>>(vx, vy, dst, ncols, nrows,
                                                                                            nd_item);

@@ -561,8 +561,8 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float *
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);

    {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        stream->submit([&](sycl::handler & cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                                  mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
                                      vx, vy, dst, ncols, nrows, item_ct1);
@ -580,10 +580,15 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1,
|
||||||
|
VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -599,10 +604,15 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0,
|
||||||
|
VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -618,10 +628,15 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1,
|
||||||
|
VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -637,10 +652,15 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0,
|
||||||
|
VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -656,10 +676,15 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK_K, QI2_K, block_q2_K,
|
||||||
|
VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -675,10 +700,15 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK_K, QI3_K, block_q3_K,
|
||||||
|
VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -694,10 +724,15 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
||||||
const sycl::range<3> block_nums(1, 1, block_num_y);
|
const sycl::range<3> block_nums(1, 1, block_num_y);
|
||||||
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
|
||||||
{
|
{
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
stream->submit([&](sycl::handler &cgh) {
|
||||||
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
|
||||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
mul_mat_vec_q<QK_K, QI4_K, block_q4_K,
|
||||||
|
VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
|
||||||
vx, vy, dst, ncols, nrows, item_ct1);
|
vx, vy, dst, ncols, nrows, item_ct1);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
@ -715,11 +750,11 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy,
|
||||||
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
|
||||||
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);
|
||||||
|
|
||||||
sycl_launch(stream, [&](sycl::handler & cgh) {
|
stream->submit([&](sycl::handler & cgh) {
|
||||||
sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
|
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
|
||||||
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols, nrows,
|
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>>(vx, vy, dst, ncols,
|
||||||
nd_item);
|
nrows, nd_item);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
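
Note: every kernel in these launchers pins the sub-group width with `[[sycl::reqd_sub_group_size(WARP_SIZE)]]`, so the warp-level shuffles inside `mul_mat_vec_q*` see a fixed lane count regardless of what the compiler would otherwise pick. A minimal standalone example of the attribute (SYCL 2020; the value must be a compile-time constant that the device actually supports, 32 is an assumption here):

    #include <sycl/sycl.hpp>

    constexpr int WARP_SIZE = 32; // assumption: a sub-group size the device offers

    int main() {
        sycl::queue q;
        q.parallel_for(sycl::nd_range<1>{128, 64},
                       [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                           // every sub-group in this kernel has exactly WARP_SIZE work-items,
                           // so cross-lane reductions have a known width
                           auto sg = it.get_sub_group();
                           (void) sg.get_local_linear_id();
                       }).wait();
        return 0;
    }
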
@@ -734,10 +769,15 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q<QK_K, QI5_K, block_q5_K,
+                                      VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
@@ -754,8 +794,8 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy,
     const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
     const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

-    sycl_launch(stream, [&](sycl::handler & cgh) {
-        sycl_parallel_for(cgh, sycl::nd_range<3>(global_size, workgroup_size),
+    stream->submit([&](sycl::handler & cgh) {
+        cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
                          [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                              mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
                                                                                            nd_item);
@@ -771,10 +811,15 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q<QK_K, QI6_K, block_q6_K,
+                                      VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
         });
@@ -791,11 +836,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS / 2, block_iq2_xxs, 1>(vx, vy, dst, ncols,
-                                                                                                  nrows, item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS/2, block_iq2_xxs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -810,11 +857,13 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS / 2, block_iq2_xs, 1>(vx, vy, dst, ncols,
-                                                                                               nrows, item_ct1);
+        stream->submit([&](sycl::handler & cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS/2, block_iq2_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -829,11 +878,14 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S / 2, block_iq2_s, 1>(vx, vy, dst, ncols, nrows,
-                                                                                            item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S/2, block_iq2_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -848,11 +900,14 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS / 2, block_iq3_xxs, 1>(vx, vy, dst, ncols,
-                                                                                                  nrows, item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS/2, block_iq3_xxs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -867,11 +922,14 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S / 2, block_iq3_s, 1>(vx, vy, dst, ncols, nrows,
-                                                                                            item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S/2, block_iq3_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -886,11 +944,14 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(vx, vy, dst, ncols, nrows,
-                                                                                        item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -905,11 +966,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(vx, vy, dst, ncols, nrows,
-                                                                                        item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_S, block_iq1_m, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -924,11 +987,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(vx, vy, dst, ncols, nrows,
-                                                                                             item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
@@ -943,11 +1009,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
     {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS / 4, block_iq4_xs, 1>(vx, vy, dst, ncols,
-                                                                                               nrows, item_ct1);
+        stream->submit([&](sycl::handler &cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS/4, block_iq4_xs, 1>(
+                            vx, vy, dst, ncols, nrows, item_ct1);
                 });
         });
     }
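
Note: all of these MMVQ launchers share one geometry: each work-group covers GGML_SYCL_MMV_Y output rows by WARP_SIZE lanes, and block_num_y work-groups tile the row dimension. A sketch of the arithmetic, with the ceil-division for block_num_y assumed (it is computed above the hunks shown here, outside this excerpt):

    #include <sycl/sycl.hpp>

    constexpr int GGML_SYCL_MMV_Y = 1;  // assumption: output rows per work-group
    constexpr int WARP_SIZE       = 32; // assumption: lanes cooperating on one row

    inline int ceil_div(int a, int b) { return (a + b - 1) / b; }

    inline sycl::nd_range<3> mmvq_range(int nrows) {
        const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); // assumed formula
        const sycl::range<3> block_nums(1, 1, block_num_y);
        const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
        return sycl::nd_range<3>(block_nums * block_dims, block_dims); // global = groups * local
    }
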
@@ -254,11 +254,12 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                                           nullptr, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                 });
         });
     }
@@ -271,13 +272,14 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
             the limit. To get the device limit, query
             info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                 sycl::range<1>(work_group_size / WARP_SIZE), cgh);
-            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
-                                           get_pointer(s_sum_acc_ct1), work_group_size);
+            cgh.parallel_for(
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                 });
         });
     }
@@ -288,12 +290,16 @@ static void group_norm_f32_sycl(const float* x, float* dst,
                                 const int ne_elements, queue_ptr stream, int device) {
     if (group_size < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             const float eps_ct4 = eps;
-            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1, nullptr,
-                                                 WARP_SIZE);
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        group_norm_f32(
+                            x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+                            nullptr, WARP_SIZE);
                 });
         });
     }
@@ -307,15 +313,19 @@ static void group_norm_f32_sycl(const float* x, float* dst,
             info::device::max_work_group_size. Adjust the work-group size if needed.
         */

-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);

             const float eps_ct4 = eps;

-            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  group_norm_f32(x, dst, group_size, ne_elements, eps_ct4, item_ct1,
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        group_norm_f32(x, dst, group_size, ne_elements,
+                                       eps_ct4, item_ct1,
                                        get_pointer(s_sum_acc_ct1), work_group_size);
                 });
         });
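
Note: each of these norm launchers picks one of two shapes. Small rows (ncols < 1024) run on a single warp with no local memory, passing nullptr as the scratch pointer; large rows use the device's maximum work-group size plus one local-memory slot per sub-group for the cross-warp reduction. A sketch of the selection logic; launch_small and launch_large are hypothetical stand-ins for the submits above, and the threshold, assert, and scratch sizing mirror the code in the hunks:

    #include <cassert>
    #include <sycl/sycl.hpp>

    constexpr int WARP_SIZE = 32; // assumption

    void launch_small(sycl::queue & q, int wg);               // hypothetical
    void launch_large(sycl::queue & q, int wg, int n_scratch); // hypothetical

    void dispatch_norm(sycl::queue & q, int ncols, int max_wg /* from device info */) {
        if (ncols < 1024) {
            // one warp per row: partial sums never leave sub-group shuffles
            launch_small(q, WARP_SIZE);
        } else {
            // one large work-group per row; each sub-group parks a partial sum in local memory
            assert(max_wg % (WARP_SIZE * WARP_SIZE) == 0); // wg/WARP_SIZE is itself a multiple of
                                                           // WARP_SIZE, so one warp can reduce it
            const int n_scratch = max_wg / WARP_SIZE;      // one slot per sub-group
            launch_large(q, max_wg, n_scratch);
        }
    }
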
@@ -330,10 +340,51 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
     const sycl::range<3> global_dims(nsamples, nchannels, nrows);
     if (ncols < 1024) {
         const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
+                    });
+        });
+    }
+    else {
+        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
+        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
+        const sycl::range<3> block_dims(1, 1, work_group_size);
+        /*
+        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        stream->submit([&](sycl::handler& cgh) {
+            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
+                                                         cgh);
+            cgh.parallel_for(
+                sycl::nd_range<3>(global_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
+                    });
+        });
+    }
+}
+
+static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
+                             const int nrows, const float eps,
+                             queue_ptr stream, int device) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
+    if (ncols < 1024) {
+        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+        stream->submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                     nullptr, WARP_SIZE);
                 });
         });
@@ -347,53 +398,21 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
             the limit. To get the device limit, query
             info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                          cgh);
-            sycl_parallel_for(cgh, sycl::nd_range<3>(global_dims * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
+            cgh.parallel_for(
+                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
+                                  block_dims),
+                [=](sycl::nd_item<3> item_ct1)
+                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                     get_pointer(s_sum_acc_ct1), work_group_size);
                 });
         });
     }
 }
-
-static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
-                             const int nrows, const float eps,
-                             queue_ptr stream, int device) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
-    if (ncols < 1024) {
-        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, nullptr, WARP_SIZE);
-                              });
-        });
-    }
-    else {
-        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
-        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
-        const sycl::range<3> block_dims(1, 1, work_group_size);
-        /*
-        DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        sycl_launch(stream, [&](sycl::handler & cgh) {
-            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
-                                                         cgh);
-            sycl_parallel_for(cgh, sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
-                              [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
-                                  l2_norm_f32(x, dst, ncols, eps, item_ct1, get_pointer(s_sum_acc_ct1),
-                                              work_group_size);
-                              });
-        });
-    }
-}

 void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     const ggml_tensor * src0 = dst->src[0];
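
Note: the DPCT1049 comments, apparently left by the DPC++ Compatibility Tool, flag launches whose work-group size may exceed the device limit; the code sidesteps the issue by sizing the group from a cached per-device maximum. The equivalent direct query in plain SYCL:

    #include <sycl/sycl.hpp>
    #include <iostream>

    int main() {
        sycl::queue q;
        // upper bound on the local (work-group) size for kernels on this device
        const size_t max_wg = q.get_device().get_info<sycl::info::device::max_work_group_size>();
        std::cout << "max work-group size: " << max_wg << "\n";
        return 0;
    }
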
@@ -232,10 +232,9 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
             the limit. To get the device limit, query
             info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                          [=](sycl::nd_item<3> item_ct1) {
-                              rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
             });
     } else {
         /*
@@ -243,10 +242,9 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
             the limit. To get the device limit, query
             info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                          [=](sycl::nd_item<3> item_ct1) {
-                              rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
             });
     }
 }
@@ -266,16 +264,14 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
     dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

     if (freq_factors == nullptr) {
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                          [=](sycl::nd_item<3> item_ct1) {
-                              rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                                                  attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
             });
     } else {
-        sycl_parallel_for(stream, sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                          [=](sycl::nd_item<3> item_ct1) {
-                              rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                                                 attn_factor, corr_dims, theta_scale, freq_factors, item_ct1);
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
             });
     }
 }
@@ -299,12 +295,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
@@ -334,12 +330,12 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
     // launch kernel
     if (freq_factors == nullptr) {
-        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                   corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
     } else {
-        sycl_parallel_for(stream, nd_range, [=](sycl::nd_item<3> item_ct1) {
+        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
             rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
         });
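
Note: the rope launchers land on the `queue::parallel_for` shortcut rather than an explicit `submit` plus handler, which is the natural form once a kernel declares no accessors or local memory. The two spellings are equivalent in SYCL 2020:

    #include <sycl/sycl.hpp>

    int main() {
        sycl::queue q;
        const sycl::nd_range<3> r({1, 1, 64}, {1, 1, 64});

        // long form: explicit command-group handler
        q.submit([&](sycl::handler & cgh) {
            cgh.parallel_for(r, [=](sycl::nd_item<3>) { /* kernel body */ });
        });

        // shortcut: no handler needed when no accessors/local memory are declared
        q.parallel_for(r, [=](sycl::nd_item<3>) { /* kernel body */ });

        q.wait();
        return 0;
    }
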
@@ -48,7 +48,7 @@ static void set_rows_sycl_q(const char * __restrict__ src0_d,
     constexpr int block_size = 256;
     const int64_t grid_size  = ceil_div(total_blocks, block_size);

-    sycl_parallel_for(stream, sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) {
+    stream->parallel_for(sycl::nd_range<1>(grid_size * block_size, block_size), [=](sycl::nd_item<1> item_ct1) {
         const int64_t i = item_ct1.get_global_linear_id();
         if (i >= total_blocks) {
             return;
@@ -129,8 +129,7 @@ static void set_rows_sycl(
     constexpr int block_size = 64;
     const int64_t grid_size  = ceil_div(total_elements, block_size);

-    sycl_parallel_for(
-        stream,
+    stream->parallel_for(
         sycl::nd_range<1>(grid_size * block_size, block_size),
         [=](sycl::nd_item<1> item_ct1) {
             k_set_rows<TIn, TOut>(
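
Note: both set_rows launchers round the grid up with ceil_div, so the last work-group may run past the element count; the kernel therefore opens with an early-out on the global linear id. A minimal sketch of the pattern:

    #include <sycl/sycl.hpp>

    // pad the grid to a whole number of work-groups ...
    inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

    void fill_ones(sycl::queue & q, float * dst, int64_t n) {
        constexpr int block_size = 64;
        const int64_t grid = ceil_div(n, block_size);
        q.parallel_for(sycl::nd_range<1>(grid * block_size, block_size), [=](sycl::nd_item<1> it) {
            const int64_t i = (int64_t) it.get_global_linear_id();
            if (i >= n) {
                return; // ... and mask off the padding work-items
            }
            dst[i] = 1.0f;
        }).wait();
    }
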
@@ -127,11 +127,11 @@ static void soft_max_f32_submitter(const float * x, const T * mask, float * dst,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, queue_ptr stream) {
-    sycl_launch(stream, [&](sycl::handler & cgh) {
+    stream->submit([&](sycl::handler &cgh) {
         sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);

-        sycl_parallel_for(
-            cgh, sycl::nd_range<3>(block_nums * block_dims, block_dims),
+        cgh.parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
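
Note: soft_max_f32_submitter is one of the launchers that must keep the explicit submit form, because its local scratch buffer has to be created against the command-group handler before the kernel is enqueued. A reduced sketch of that dependency (the body is a placeholder):

    #include <sycl/sycl.hpp>

    void launch_with_scratch(sycl::queue & q, size_t n_local_scratch, sycl::nd_range<3> r) {
        q.submit([&](sycl::handler & cgh) {
            // local memory is a per-work-group resource: it must be declared on the handler
            sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
            cgh.parallel_for(r, [=](sycl::nd_item<3> it) {
                float * buf = local_buf_acc.get_multi_ptr<sycl::access::decorated::no>().get();
                if (it.get_local_linear_id() == 0) {
                    buf[0] = 0.0f; // placeholder use of the scratch buffer
                }
            });
        });
    }
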
@@ -45,8 +45,13 @@ static void timestep_embedding_f32_sycl(
     int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
     sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
     sycl::range<3> gridDim(1, ne00, num_blocks);
-    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-        timestep_embedding_f32(x, dst, nb1, dim, max_period, item_ct1);
+    stream->parallel_for(
+        sycl::nd_range<3>(
+            gridDim * block_dims, block_dims),
+        [=](sycl::nd_item<3> item_ct1) {
+            timestep_embedding_f32(
+                x, dst, nb1, dim, max_period, item_ct1
+            );
     });
 }
@@ -207,11 +207,12 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

     // Submit kernel
     if (C / H == WKV_BLOCK_SIZE) {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
                         B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -219,11 +220,12 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
             });
         });
     } else {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
                         B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -262,11 +264,12 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

     // Submit kernel
     if (C / H == WKV_BLOCK_SIZE) {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
                         B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
@@ -274,11 +277,12 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
             });
         });
     } else {
-        sycl_launch(stream, [&](sycl::handler & cgh) {
+        stream->submit([&](sycl::handler& cgh) {
             sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);

-            sycl_parallel_for(
-                cgh, sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            cgh.parallel_for(
+                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
                     rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
                         B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
                         item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
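
Note: the wkv kernels hand their local accessor to the kernel as a raw float pointer via `get_multi_ptr<sycl::access::decorated::no>().get()`, the SYCL 2020 replacement for the deprecated `accessor::get_pointer`. Reduced example:

    #include <sycl/sycl.hpp>

    void zero_shared(sycl::queue & q, size_t shared_mem_size) {
        q.submit([&](sycl::handler & cgh) {
            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
            cgh.parallel_for(sycl::nd_range<1>(64, 64), [=](sycl::nd_item<1> it) {
                // raw pointer into work-group local memory, as passed to the wkv kernels
                float * smem = shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get();
                smem[it.get_local_linear_id()] = 0.0f; // assumes shared_mem_size >= 64
            });
        });
    }
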
@@ -5,8 +5,14 @@
 #include "ggml-cpu.h"
 #endif

+// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
+#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
+
 #include <vulkan/vulkan.hpp>

+// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
+VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
+
 #include <algorithm>
 #include <cmath>
 #include <iomanip>
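
Note: this is the standard Vulkan-Hpp dynamic-dispatch recipe: define `VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1` before including the header, instantiate the default dispatcher storage once per binary, then initialize the dispatcher in stages. Later hunks in this file add the two init calls; condensed, the full sequence looks like this (the per-device stage shown last is optional and not part of this diff):

    #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
    #include <vulkan/vulkan.hpp>

    VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE

    void init_dispatch(vk::Instance instance) {
        // stage 1: loader entry point only; enough to create an instance
        VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
        // stage 2: instance-level function pointers (extensions included)
        VULKAN_HPP_DEFAULT_DISPATCHER.init(instance);
        // stage 3 (optional): VULKAN_HPP_DEFAULT_DISPATCHER.init(device);
    }
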
@@ -121,6 +127,8 @@ struct vk_pipeline_struct {
     bool needed {};
     // set to true when the shader has been compiled
     bool compiled {};
+    // number of registers used, extracted from pipeline executable properties
+    uint32_t register_count {};
 };

 typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
@@ -429,6 +437,8 @@ struct vk_device_struct {

     bool coopmat2;

+    bool pipeline_executable_properties_support {};
+
     size_t idx;

     bool mul_mat_l[GGML_TYPE_COUNT];
@@ -1221,8 +1231,6 @@ static std::string format_size(size_t size) {
     return oss.str();
 }

-static std::mutex log_mutex;
-
 class vk_memory_logger {
   public:
     vk_memory_logger(): total_device(0), total_host(0) {}
@@ -1412,6 +1420,8 @@ struct ggml_backend_vk_buffer_context {
 };

 #ifdef GGML_VULKAN_MEMORY_DEBUG
+static std::mutex log_mutex;
+
 void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
     std::lock_guard<std::mutex> guard(log_mutex);
     vk_buffer buf = buf_ref.lock();
@@ -1603,6 +1613,20 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
     }

+    if (device->pipeline_executable_properties_support) {
+        vk::PipelineExecutableInfoKHR executableInfo;
+        executableInfo.pipeline = pipeline->pipeline;
+
+        auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
+        for (auto & s : statistics) {
+            // "Register Count" is reported by NVIDIA drivers.
+            if (strcmp(s.name, "Register Count") == 0) {
+                VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
+                pipeline->register_count = (uint32_t)s.value.u64;
+            }
+        }
+    }
+
     {
         std::lock_guard<std::recursive_mutex> guard(device->mutex);
         device->all_pipelines.push_back(pipeline);
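
Note: VK_KHR_pipeline_executable_properties exposes per-executable statistics after pipeline creation; the hunk above scans them for the "Register Count" statistic that NVIDIA drivers report. A hedged sketch of the same query against a bare vk::Device, assuming the extension and its pipelineExecutableInfo feature were enabled at device creation, as the other hunks in this file arrange:

    #include <cstring>
    #include <vulkan/vulkan.hpp>

    uint32_t query_register_count(vk::Device device, vk::Pipeline pipeline) {
        vk::PipelineExecutableInfoKHR info;
        info.pipeline = pipeline; // executableIndex defaults to 0: the first executable

        uint32_t regs = 0;
        for (auto & s : device.getPipelineExecutableStatisticsKHR(info)) {
            // statistic names are driver-defined; "Register Count" is what NVIDIA reports
            if (strcmp(s.name, "Register Count") == 0 &&
                s.format == vk::PipelineExecutableStatisticFormatKHR::eUint64) {
                regs = (uint32_t) s.value.u64;
            }
        }
        return regs;
    }
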
@@ -1960,7 +1984,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
         }
     }

-    if (buf->device_memory == VK_NULL_HANDLE) {
+    if (!buf->device_memory) {
         device->device.destroyBuffer(buf->buffer);
         throw vk::OutOfDeviceMemoryError("No suitable memory type found");
     }
@@ -3610,6 +3634,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool amd_shader_core_properties2 = false;
     bool pipeline_robustness = false;
     bool coopmat2_support = false;
+    bool pipeline_executable_properties_support = false;
     device->coopmat_support = false;
     device->integer_dot_product = false;
     bool bfloat16_support = false;
@@ -3652,6 +3677,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                    !getenv("GGML_VK_DISABLE_BFLOAT16")) {
             bfloat16_support = true;
 #endif
+        } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
+            pipeline_executable_properties_support = true;
         }
     }

@@ -3878,8 +3905,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_shader_integer_dot_product");
     }

+    VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
+    pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
+    if (pipeline_executable_properties_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&pep_features;
+        last_struct = (VkBaseOutStructure *)&pep_features;
+        device_extensions.push_back("VK_KHR_pipeline_executable_properties");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

+    device->pipeline_executable_properties_support = pipeline_executable_properties_support;
+
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;

 #if defined(VK_KHR_shader_bfloat16)
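
Note: optional features are enabled by splicing extension structs onto the `VkPhysicalDeviceFeatures2::pNext` chain before the single `vkGetPhysicalDeviceFeatures2` call; `last_struct` is the running tail pointer of that chain. The generic shape of the pattern, using the same feature struct as the hunk above:

    #include <vulkan/vulkan.h>

    void chain_example(VkPhysicalDevice phys_dev, bool want_pep) {
        VkPhysicalDeviceFeatures2 features2 = {};
        features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
        VkBaseOutStructure * last_struct = (VkBaseOutStructure *) &features2;

        VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features = {};
        pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
        if (want_pep) {
            // append to the chain tail; the driver fills in ::pipelineExecutableInfo
            last_struct->pNext = (VkBaseOutStructure *) &pep_features;
            last_struct = (VkBaseOutStructure *) &pep_features;
        }

        vkGetPhysicalDeviceFeatures2(phys_dev, &features2);
    }
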
@ -4395,6 +4432,9 @@ static void ggml_vk_instance_init() {
|
||||||
}
|
}
|
||||||
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
VK_LOG_DEBUG("ggml_vk_instance_init()");
|
||||||
|
|
||||||
|
// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
|
||||||
|
VULKAN_HPP_DEFAULT_DISPATCHER.init(vkGetInstanceProcAddr);
|
||||||
|
|
||||||
uint32_t api_version = vk::enumerateInstanceVersion();
|
uint32_t api_version = vk::enumerateInstanceVersion();
|
||||||
|
|
||||||
if (api_version < VK_API_VERSION_1_2) {
|
if (api_version < VK_API_VERSION_1_2) {
|
||||||
|
|
@@ -4462,6 +4502,9 @@ static void ggml_vk_instance_init() {
 
     vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
 
+    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
+    VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
+
     std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
 
     // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
@@ -4497,7 +4540,7 @@ static void ggml_vk_instance_init() {
             new_driver.pNext = &new_id;
             devices[i].getProperties2(&new_props);
 
-            if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
+            if (new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) {
                 // Check if there are two physical devices corresponding to the same GPU
                 auto old_device = std::find_if(
                     vk_instance.device_indices.begin(),
@@ -4567,7 +4610,7 @@ static void ggml_vk_instance_init() {
         }
     }
 
-    // If no dedicated GPUs found, fall back to the first non-CPU device.
+    // If no GPUs found, fall back to the first non-CPU device.
     // If only CPU devices are available, return without devices.
     if (vk_instance.device_indices.empty()) {
         for (size_t i = 0; i < devices.size(); i++) {
@@ -12078,12 +12121,63 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total
     }
 }
 
+static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {
+    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
+
+    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
+
+    vk::PhysicalDeviceProperties2 props = {};
+    device.getProperties2(&props);
+
+    return props.properties.deviceType;
+}
+
+static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
+    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
+
+    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
+
+    const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+    bool ext_support = false;
+
+    for (const auto& properties : ext_props) {
+        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
+            ext_support = true;
+            break;
+        }
+    }
+
+    if (!ext_support) {
+        return "";
+    }
+
+    vk::PhysicalDeviceProperties2 props = {};
+    vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
+
+    props.pNext = &pci_bus_info;
+
+    device.getProperties2(&props);
+
+    const uint32_t pci_domain = pci_bus_info.pciDomain;
+    const uint32_t pci_bus = pci_bus_info.pciBus;
+    const uint32_t pci_device = pci_bus_info.pciDevice;
+    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
+
+    char pci_bus_id[16] = {};
+    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
+
+    return std::string(pci_bus_id);
+}
+
 //////////////////////////
 
 struct ggml_backend_vk_device_context {
     size_t device;
     std::string name;
     std::string description;
+    bool is_integrated_gpu;
+    std::string pci_bus_id;
 };
 
 static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
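For reference, the snprintf format in ggml_backend_vk_get_device_pci_id yields the conventional domain:bus:device.function string. A small sketch with made-up values (not taken from a real device):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t pci_domain = 0x0000, pci_bus = 0x03, pci_device = 0x00;
        const uint8_t  pci_function = 0;
        char pci_bus_id[16] = {};
        snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x",
                 pci_domain, pci_bus, pci_device, pci_function);
        printf("%s\n", pci_bus_id); // prints "0000:03:00.0"
    }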
@@ -12112,14 +12206,18 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg
 }
 
 static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
-    UNUSED(dev);
-    return GGML_BACKEND_DEVICE_TYPE_GPU;
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+
+    return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;
 }
 
 static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+
     props->name = ggml_backend_vk_device_get_name(dev);
     props->description = ggml_backend_vk_device_get_description(dev);
     props->type = ggml_backend_vk_device_get_type(dev);
+    props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
     ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
         /* .async = */ false,
@@ -12386,8 +12484,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
 
             if (
-                src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 ||
-                src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32
+                (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
+                (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
             ) {
                 return true;
             }
@@ -12552,6 +12650,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg,
         ctx->device = i;
         ctx->name = GGML_VK_NAME + std::to_string(i);
         ctx->description = desc;
+        ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
+        ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
         devices.push_back(new ggml_backend_device {
             /* .iface = */ ggml_backend_vk_device_i,
             /* .reg   = */ reg,
@@ -13038,16 +13138,16 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
     } else if (tensor->op == GGML_OP_IM2COL_3D) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
-        const int32_t s1 = tensor->op_params[2];
+        const int32_t s2 = tensor->op_params[2];
         const int32_t p0 = tensor->op_params[3];
         const int32_t p1 = tensor->op_params[4];
-        const int32_t p1 = tensor->op_params[5];
+        const int32_t p2 = tensor->op_params[5];
         const int32_t d0 = tensor->op_params[6];
         const int32_t d1 = tensor->op_params[7];
-        const int32_t d1 = tensor->op_params[8];
+        const int32_t d2 = tensor->op_params[8];
         const int32_t IC = tensor->op_params[9];
 
-        tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
+        tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
     } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
        const int32_t dim = tensor->op_params[0];
        const int32_t max_period = tensor->op_params[1];
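The hunk above untangles a copy-paste bug (s1, p1 and d1 were each declared twice while s2, p2 and d2 were missing) and switches the reference clone to ggml_im2col_3d. The op_params layout implied by the corrected code, restated compactly for quick reference:

    const int32_t s0 = tensor->op_params[0], s1 = tensor->op_params[1], s2 = tensor->op_params[2]; // strides per axis
    const int32_t p0 = tensor->op_params[3], p1 = tensor->op_params[4], p2 = tensor->op_params[5]; // paddings per axis
    const int32_t d0 = tensor->op_params[6], d1 = tensor->op_params[7], d2 = tensor->op_params[8]; // dilations per axis
    const int32_t IC = tensor->op_params[9];                                                       // input channels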
@@ -29,7 +29,7 @@ void main() {
         uint qs = data_a[ib].qs[4 * ib32 + l];
         const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
         qs |= (qh << (8 - 2 * l)) & 0x300;
-        const uvec2 grid = iq2s_grid[qs & 511];
+        const uvec2 grid = iq2s_grid[qs];
         const u8vec4 grid0 = unpack8(grid.x);
         const u8vec4 grid1 = unpack8(grid.y);
         data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));
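The dropped `& 511` matters because the line before it ORs two qh bits into positions 8..9, making qs a 10-bit grid index; masking back to 9 bits silently discarded the top qh bit. A tiny sketch of the failure mode (values illustrative):

    #include <cstdint>

    int main() {
        uint32_t qs = 0x40;          // low 8 bits from the qs byte
        qs |= 0x300;                 // two high bits OR-ed in from qh -> 0x340
        uint32_t old_idx = qs & 511; // 0x140: bit 9 from qh is lost
        uint32_t new_idx = qs;       // 0x340: full 10-bit index into the grid
        return old_idx == new_idx;   // 0 here: the old and new lookups differ
    }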
@@ -33,7 +33,8 @@ void main() {
     [[unroll]] for (uint l = 0; l < 4; ++l) {
         const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
         const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
-        const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
+        const uint qs = data_a[ib].qs[8 * is + l];
+        const uvec2 grid = iq2xxs_grid[qs];
         const u8vec4 grid0 = unpack8(grid.x);
         const u8vec4 grid1 = unpack8(grid.y);
         data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
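The sign8 line works because this format stores only 7 of the 8 sign bits per group; the 8th is defined so the full byte has even parity, which bitCount recovers. A self-contained C++ equivalent (the shader omits the `& 1` because bits above position 7 are never read):

    #include <bit>
    #include <cstdint>

    // Reconstruct the 8th sign bit from the stored 7: the encoder chooses it
    // so that the complete byte has an even number of set bits.
    uint32_t sign8_from_sign7(uint32_t sign7) {
        return sign7 | ((std::popcount(sign7) & 1u) << 7);
    }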
@@ -22,15 +22,16 @@ void main() {
     const uint b_idx = 256 * ib + 32 * is;
 
     const float d = float(data_a[ib].d);
-    const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
+    const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));
 
     // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
     uint qh = data_a[ib].qh[is];
     [[unroll]] for (uint l = 0; l < 8; ++l) {
-        uint qs = data_a[ib].qs[8 * is + l];
-        uint gidx = qs | ((qh << (8 - l)) & 256);
-        uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
-        u8vec4 grid = unpack8(iq3s_grid[gidx]);
+        const uint iqs = 8 * is + l;
+        const uint qs = data_a[ib].qs[iqs];
+        const uint gidx = qs | ((qh << (8 - l)) & 256);
+        const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
+        const u8vec4 grid = unpack8(iq3s_grid[gidx]);
         data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
         data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
         data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
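The scales[is / 2] fix reflects that this format packs two 4-bit scales per byte: sub-block 2k lives in the low nibble of byte k and sub-block 2k+1 in the high nibble, so the byte index advances every other sub-block while is % 2 picks the nibble (the sign-byte index is fixed the same way, from 8 * is + l / 2 to iqs / 2). Equivalent C++ sketch:

    #include <cstdint>

    uint32_t scale_for_subblock(const uint8_t * scales, uint32_t is) {
        return (scales[is / 2] >> (4 * (is % 2))) & 0xF; // the old code indexed scales[is]
    }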
@@ -35,8 +35,10 @@ void main() {
         const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
         // Restore parity bit.
         const uint sign8 = sign7 | (bitCount(sign7) << 7);
-        const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
-        const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
+        const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
+        const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
+        const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
+        const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
         data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
         data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
         data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));
@@ -183,6 +183,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
 shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
 #endif
 
+#include "mul_mm_funcs.comp"
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -310,550 +312,13 @@ void main() {
 
     for (uint block = start_k; block < end_k; block += BK) {
         [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
-#if defined(DATA_A_F32) || defined(DATA_A_F16)
-#if LOAD_VEC_A == 8
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-            A_TYPE32 aa = A_TYPE32(data_a[idx]);
-            buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
-            buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
-            buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
-            buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
-            buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
-#elif LOAD_VEC_A == 4
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-            A_TYPE32 aa = A_TYPE32(data_a[idx]);
-            buf_a[buf_idx ] = FLOAT_TYPE(aa.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
-#else
-            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
-            } else {
-                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
-            }
-#endif
-#elif defined(DATA_A_BF16)
-#if LOAD_VEC_A == 4
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-            buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
-            buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
-            buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
-            buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
-#else
-            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
-                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
-            } else {
-                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
-            }
-#endif
-#elif defined(DATA_A_Q4_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
-
-            const uint ib = idx / 4;
-            const uint iqs = idx & 0x03;
-
-            const float d = float(data_a_packed16[ib].d);
-            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
-            const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
-            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
-            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
-            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
-            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
-            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
-            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
-            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
-#elif defined(DATA_A_Q4_1)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
-
-            const uint ib = idx / 4;
-            const uint iqs = idx & 0x03;
-
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
-            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
-            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
-            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
-            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
-            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
-            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
-            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
-            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
-#elif defined(DATA_A_Q5_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
-
-            const float d = float(data_a_packed16[ib].d);
-            const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
-            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
-            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
-
-            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
-            const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
-#elif defined(DATA_A_Q5_1)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
-
-            const float d = float(data_a_packed16[ib].d);
-            const float m = float(data_a_packed16[ib].m);
-            const uint uint_qh = data_a_packed16[ib].qh;
-            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
-            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
-
-            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
-            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
-            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
-#elif defined(DATA_A_Q8_0)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
-
-            const float d = float(data_a_packed16[ib].d);
-            const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
-            const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
-            const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
-#elif defined(DATA_A_Q2_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint iqs = idx % 128; // 0..127
-
-            const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
-            const uint scalesi = iqs / 8; // 0..15
-            const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-
-            const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
-            const uint scales = data_a[ib].scales[scalesi];
-            const vec2 d = vec2(data_a[ib].d);
-
-            const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_Q3_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint iqs = idx % 128; // 0..127
-
-            const uint n = iqs / 64; // 0,1
-            const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
-            const uint hmi = (iqs % 16) * 2; // 0,2,4..30
-            const uint j = (iqs % 64) / 4; // 0..3
-            const uint is = iqs / 8; // 0..15
-            const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
-            const uint qsshift = halfsplit * 2; // 0,2,4,6
-            const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
-
-            const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
-                                   | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
-            const float dl = float(data_a[ib].d) * float(us - 32);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
-#elif defined(DATA_A_Q4_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint iqs = idx % 128; // 0..127
-
-            const uint n = iqs / 32; // 0,1,2,3
-            const uint b = (iqs % 32) / 16; // 0,1
-            const uint is = 2 * n + b; // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            const uint scidx0 = (is < 4) ? is : (is + 4);
-            const uint scidx1 = (is < 4) ? is : (is - 4);
-            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
-            const uint scidxshift1 = (is < 4) ? 0 : 2;
-            const uint mbidx0 = is + 4;
-            const uint mbidx1 = (is < 4) ? is + 4 : is;
-            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
-            const uint mbidxshift0 = (is < 4) ? 0 : 4;
-            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
-            const uint mbidxshift1 = (is < 4) ? 0 : 2;
-
-            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
-            const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
-
-            const float d = loadd.x * sc;
-            const float m = -loadd.y * mbyte;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
-#elif defined(DATA_A_Q5_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint iqs = idx % 128; // 0..127
-
-            const uint n = iqs / 32; // 0,1,2,3
-            const uint b = (iqs % 32) / 16; // 0,1
-            const uint is = 2 * n + b; // 0..7
-            const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
-            const uint qhi = (iqs % 16) * 2; // 0,2,4..30
-
-            const uint8_t hm = uint8_t(1 << (iqs / 16));
-
-            const vec2 loadd = vec2(data_a[ib].d);
-
-            const uint scidx0 = (is < 4) ? is : (is + 4);
-            const uint scidx1 = (is < 4) ? is : (is - 4);
-            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
-            const uint scidxshift1 = (is < 4) ? 0 : 2;
-            const uint mbidx0 = is + 4;
-            const uint mbidx1 = (is < 4) ? is + 4 : is;
-            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
-            const uint mbidxshift0 = (is < 4) ? 0 : 4;
-            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
-            const uint mbidxshift1 = (is < 4) ? 0 : 2;
-
-            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
-            const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
-
-            const float d = loadd.x * sc;
-            const float m = -loadd.y * mbyte;
-
-            buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
-#elif defined(DATA_A_Q6_K)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint iqs = idx % 128; // 0..127
-
-            const uint n = iqs / 64; // 0,1
-            const uint b = (iqs % 64) / 32; // 0,1
-            const uint is_b = (iqs % 16) / 8; // 0,1
-            const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
-            const uint is = 8 * n + qhshift + is_b; // 0..15
-            const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
-            const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
-
-            const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
-            buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
-#elif defined(DATA_A_IQ1_S)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 32; // 8 values per idx
-            const uint ib32 = (idx % 32) / 4; // 0..7
-            const uint ib8 = idx % 32;
-
-            const float d = float(data_a[ib].d);
-            const uint qh = data_a[ib].qh[ib32];
-            const uint qs = data_a[ib].qs[ib8];
-            const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
-            const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
-            const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
-
-            [[unroll]] for (int k = 0; k < 8; ++k) {
-                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
-            }
-#elif defined(DATA_A_IQ1_M)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 32; // 8 values per idx
-            const uint ib8 = idx % 32;
-            const uint ib16 = ib8 / 2;
-
-            const uint16_t[4] scales = data_a[ib].scales;
-            const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
-            const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
-            const uint sc = scales[ib8 / 8];
-            const uint qs = data_a[ib].qs[ib8];
-            const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
-            const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
-            const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
-            const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
-
-            [[unroll]] for (int k = 0; k < 8; ++k) {
-                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
-            }
-#elif defined(DATA_A_IQ2_XXS)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 32; // 8 values per idx
-            const uint ib32 = (idx % 32) / 4; // 0..7
-            const uint ib8 = idx % 4;
-
-            const float d = float(data_a[ib].d);
-            const uint qs = data_a[ib].qs[8 * ib32 + ib8];
-            const uint signs = pack32(u8vec4(
-                data_a[ib].qs[8*ib32 + 4],
-                data_a[ib].qs[8*ib32 + 5],
-                data_a[ib].qs[8*ib32 + 6],
-                data_a[ib].qs[8*ib32 + 7]
-            ));
-            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
-            const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
-            const uint sign = sign7 | (bitCount(sign7) << 7);
-            const uvec2 grid = iq2xxs_grid[qs];
-            const vec4 grid0 = vec4(unpack8(grid.x));
-            const vec4 grid1 = vec4(unpack8(grid.y));
-
-            buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
-            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
-            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
-            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
-            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
-            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
-            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
-            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
-#elif defined(DATA_A_IQ2_XS)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 32; // 8 values per idx
-            const uint ib32 = (idx % 32) / 4; // 0..7
-            const uint ib8 = idx % 4; // 0..3
-
-            const float d = float(data_a[ib].d);
-            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
-            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
-            const uint qs = data_a[ib].qs[4 * ib32 + ib8];
-            const uint sign7 = qs >> 9;
-            const uint sign = sign7 | (bitCount(sign7) << 7);
-            const uvec2 grid = iq2xs_grid[qs & 511];
-            const vec4 grid0 = vec4(unpack8(grid.x));
-            const vec4 grid1 = vec4(unpack8(grid.y));
-
-            buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
-            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
-            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
-            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
-            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
-            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
-            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
-            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
-#elif defined(DATA_A_IQ2_S)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 32; // 8 values per idx
-            const uint ib8 = idx % 32; // 0..31
-            const uint ib32 = ib8 / 4; // 0..7
-
-            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
-            const uint qs = data_a[ib].qs[ib8];
-            const uint qh = data_a[ib].qh[ib32];
-            const uint qhshift = 2 * (ib8 % 4);
-            const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
-
-            const float d = float(data_a[ib].d);
-            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
-            const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
-            const vec4 grid0 = vec4(unpack8(grid.x));
-            const vec4 grid1 = vec4(unpack8(grid.y));
-
-            buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
-            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
-            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
-            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
-            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
-            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
-            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
-            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
-#elif defined(DATA_A_IQ3_XXS)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 64; // 4 values per idx
-            const uint iqs = idx % 64; // 0..63
-            const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
-
-            const float d = float(data_a[ib].d);
-            const uint qs = data_a[ib].qs[iqs];
-            const uint signs = pack32(u8vec4(
-                data_a[ib].qs[is+0],
-                data_a[ib].qs[is+1],
-                data_a[ib].qs[is+2],
-                data_a[ib].qs[is+3]
-            ));
-            const float db = d * 0.5 * (0.5 + (signs >> 28));
-            const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
-            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
-            const uint grid = iq3xxs_grid[qs];
-            const vec4 v = db * vec4(unpack8(grid));
-
-            buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
-#elif defined(DATA_A_IQ3_S)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 64; // 4 values per idx
-            const uint iqs = idx % 64; // 0..63
-            const uint iqh = iqs / 8;
-
-            const float d = float(data_a[ib].d);
-            const uint qs = data_a[ib].qs[iqs];
-            const uint qh = data_a[ib].qh[iqh];
-            const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
-            const uint scale = data_a[ib].scales[iqs / 16];
-            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
-            const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
-            const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
-            const vec4 v = db * vec4(unpack8(grid));
-
-            buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
-            buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
-            buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
-#elif defined(DATA_A_IQ4_XS)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-            const uint ib = idx / 128; // 2 values per idx
-            const uint ib32 = (idx % 128) / 16; // 0..7
-            const uint iq = 16 * ib32 + 2 * (idx % 8);
-
-            const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
-            const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
-            const uint qshift = (idx & 8) >> 1;
-            u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
-            qs = (qs >> qshift) & uint8_t(0xF);
-
-            const float d = float(data_a[ib].d);
-            const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
-#elif defined(DATA_A_IQ4_NL)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
-
-            const uint ib = idx / 8;
-            const uint iqs = idx & 0x07;
-
-            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
-            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
-            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
-            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
-            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
-#elif defined(DATA_A_MXFP4)
-            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
-
-            const uint ib = idx / 8;
-            const uint iqs = (idx & 0x07) * 2;
-
-            const float d = e8m0_to_fp32(data_a[ib].e);
-            const uint vui = uint(data_a[ib].qs[iqs]);
-            const uint vui2 = uint(data_a[ib].qs[iqs+1]);
-
-            buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
-            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
-            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
-#endif
+            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a, end_k);
         }
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
-#if LOAD_VEC_B == 8
-#ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[loadc_b + l];
-            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
+#if !defined(MUL_MAT_ID)
+            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b, end_k);
 #else
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
-            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-#if defined(DATA_B_BF16)
-            B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
-#else
-            B_TYPE32 bb = B_TYPE32(data_b[idx]);
-#endif
-            buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
-            buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
-            buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
-            buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
-            buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
-#elif LOAD_VEC_B == 4
-#ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[loadc_b + l];
-            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
-#else
-            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
-#endif
-            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
-#if defined(DATA_B_BF16)
-            B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
-#else
-            B_TYPE32 bb = B_TYPE32(data_b[idx]);
-#endif
-            buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
-            buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
-            buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
-            buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
-#elif !MUL_MAT_ID
-            if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
-            } else {
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
-            }
-#else
-            const uint row_i = ic * BN + loadc_b + l;
-            if (row_i < _ne1 && block + loadr_b < end_k) {
-                const u16vec2 row_idx = row_ids[loadc_b + l];
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
-            } else {
-                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
-            }
+            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b, end_k);
 #endif
         }
 
@@ -0,0 +1,568 @@
+void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint idx_k, const uint end_k) {
+#if defined(DATA_A_F32) || defined(DATA_A_F16)
+#if LOAD_VEC_A == 8
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+    FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
+    buf_a[buf_idx ] = aa[0].x;
+    buf_a[buf_idx + 1] = aa[0].y;
+    buf_a[buf_idx + 2] = aa[0].z;
+    buf_a[buf_idx + 3] = aa[0].w;
+    buf_a[buf_idx + 4] = aa[1].x;
+    buf_a[buf_idx + 5] = aa[1].y;
+    buf_a[buf_idx + 6] = aa[1].z;
+    buf_a[buf_idx + 7] = aa[1].w;
+#elif LOAD_VEC_A == 4
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+    FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
+    buf_a[buf_idx ] = aa.x;
+    buf_a[buf_idx + 1] = aa.y;
+    buf_a[buf_idx + 2] = aa.z;
+    buf_a[buf_idx + 3] = aa.w;
+#else
+    if (idx_m < p.M && idx_k < end_k) {
+        buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
+    } else {
+        buf_a[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
+    }
+#endif
+#elif defined(DATA_A_BF16)
+#if LOAD_VEC_A == 4
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+    FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
+    buf_a[buf_idx ] = aa.x;
+    buf_a[buf_idx + 1] = aa.y;
+    buf_a[buf_idx + 2] = aa.z;
+    buf_a[buf_idx + 3] = aa.w;
+#else
+    if (idx_m < p.M && idx_k < end_k) {
+        buf_a[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_a[pos_a + col * p.stride_a + row]);
+    } else {
+        buf_a[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(uint16_t(0));
+    }
+#endif
+#elif defined(DATA_A_Q4_0)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
+
+    const uint ib = idx / 4;
+    const uint iqs = idx & 0x03;
+
+    const float d = float(data_a_packed16[ib].d);
+    const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+    const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
+    const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+    buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+    buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+    buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+    buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+    buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+    buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+    buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
+#elif defined(DATA_A_Q4_1)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + 4 * row;
+
+    const uint ib = idx / 4;
+    const uint iqs = idx & 0x03;
+
+    const float d = float(data_a_packed16[ib].d);
+    const float m = float(data_a_packed16[ib].m);
+    const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
+    const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
+    const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
+    buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
+    buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
+    buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
+    buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
+    buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
+    buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
+    buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
+#elif defined(DATA_A_Q5_0)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+
+    const uint ib = idx / 8;
+    const uint iqs = idx & 0x07;
+
+    const float d = float(data_a_packed16[ib].d);
+    const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
+    const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+    const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+    const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+    buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
+    buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+    buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
+#elif defined(DATA_A_Q5_1)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
+
+    const uint ib = idx / 8;
+    const uint iqs = idx & 0x07;
+
+    const float d = float(data_a_packed16[ib].d);
+    const float m = float(data_a_packed16[ib].m);
+    const uint uint_qh = data_a_packed16[ib].qh;
+    const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
+    const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
+
+    const uint vui = uint(data_a_packed16[ib].qs[iqs]);
+    const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+    buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
+    buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
+    buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
+#elif defined(DATA_A_Q8_0)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 8;
+    const uint iqs = idx & 0x07;
+
+    const float d = float(data_a_packed16[ib].d);
+    const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
+    const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+    buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+    buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
+    buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
+#elif defined(DATA_A_Q2_K)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 128; // 2 values per idx
+    const uint iqs = idx % 128; // 0..127
+
+    const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
+    const uint scalesi = iqs / 8; // 0..15
+    const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+
+    const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+    const uint scales = data_a[ib].scales[scalesi];
+    const vec2 d = vec2(data_a[ib].d);
+
+    const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
+
+    buf_a[buf_idx ] = FLOAT_TYPE(v.x);
+    buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+#elif defined(DATA_A_Q3_K)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 128; // 2 values per idx
+    const uint iqs = idx % 128; // 0..127
+
+    const uint n = iqs / 64; // 0,1
+    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+    const uint hmi = (iqs % 16) * 2; // 0,2,4..30
+    const uint j = (iqs % 64) / 4; // 0..3
+    const uint is = iqs / 8; // 0..15
+    const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
+    const uint qsshift = halfsplit * 2; // 0,2,4,6
+    const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
+
+    const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
+                           | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
+    const float dl = float(data_a[ib].d) * float(us - 32);
+
+    buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
+    buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
+#elif defined(DATA_A_Q4_K)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 128; // 2 values per idx
+    const uint iqs = idx % 128; // 0..127
+
+    const uint n = iqs / 32; // 0,1,2,3
+    const uint b = (iqs % 32) / 16; // 0,1
+    const uint is = 2 * n + b; // 0..7
+    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+
+    const vec2 loadd = vec2(data_a[ib].d);
+
+    const uint scidx0 = (is < 4) ? is : (is + 4);
+    const uint scidx1 = (is < 4) ? is : (is - 4);
+    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+    const uint scidxshift1 = (is < 4) ? 0 : 2;
+    const uint mbidx0 = is + 4;
+    const uint mbidx1 = (is < 4) ? is + 4 : is;
+    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+    const uint mbidxshift0 = (is < 4) ? 0 : 4;
+    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+    const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+    const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+    const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
+    const float d = loadd.x * sc;
+    const float m = -loadd.y * mbyte;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m));
+    buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
+#elif defined(DATA_A_Q5_K)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 128; // 2 values per idx
+    const uint iqs = idx % 128; // 0..127
+
+    const uint n = iqs / 32; // 0,1,2,3
+    const uint b = (iqs % 32) / 16; // 0,1
+    const uint is = 2 * n + b; // 0..7
+    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
+    const uint qhi = (iqs % 16) * 2; // 0,2,4..30
+
+    const uint8_t hm = uint8_t(1 << (iqs / 16));
+
+    const vec2 loadd = vec2(data_a[ib].d);
+
+    const uint scidx0 = (is < 4) ? is : (is + 4);
+    const uint scidx1 = (is < 4) ? is : (is - 4);
+    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
+    const uint scidxshift1 = (is < 4) ? 0 : 2;
+    const uint mbidx0 = is + 4;
+    const uint mbidx1 = (is < 4) ? is + 4 : is;
+    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
+    const uint mbidxshift0 = (is < 4) ? 0 : 4;
+    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
+    const uint mbidxshift1 = (is < 4) ? 0 : 2;
+
+    const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
+    const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
+
+    const float d = loadd.x * sc;
+    const float m = -loadd.y * mbyte;
+
+    buf_a[buf_idx ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m));
+    buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
+#elif defined(DATA_A_Q6_K)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 128; // 2 values per idx
+    const uint iqs = idx % 128; // 0..127
+
+    const uint n = iqs / 64; // 0,1
+    const uint b = (iqs % 64) / 32; // 0,1
+    const uint is_b = (iqs % 16) / 8; // 0,1
+    const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
+    const uint is = 8 * n + qhshift + is_b; // 0..15
+    const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
+    const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
+
+    const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
+
+    buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
+    buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
+#elif defined(DATA_A_IQ1_S)
+    const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
+    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
+
+    const uint ib = idx / 32; // 8 values per idx
+    const uint ib32 = (idx % 32) / 4; // 0..7
+    const uint ib8 = idx % 32;
+
+    const float d = float(data_a[ib].d);
+    const uint qh = data_a[ib].qh[ib32];
+    const uint qs = data_a[ib].qs[ib8];
+    const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
+    const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
+    const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
+
+    [[unroll]] for (int k = 0; k < 8; ++k) {
+        buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||||
|
}
|
||||||
|
#elif defined(DATA_A_IQ1_M)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
|
const uint ib8 = idx % 32;
|
||||||
|
const uint ib16 = ib8 / 2;
|
||||||
|
|
||||||
|
const uint16_t[4] scales = data_a[ib].scales;
|
||||||
|
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
|
||||||
|
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
|
||||||
|
const uint sc = scales[ib8 / 8];
|
||||||
|
const uint qs = data_a[ib].qs[ib8];
|
||||||
|
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
|
||||||
|
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
|
||||||
|
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
|
||||||
|
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
|
||||||
|
|
||||||
|
[[unroll]] for (int k = 0; k < 8; ++k) {
|
||||||
|
buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
|
||||||
|
}
|
||||||
|
#elif defined(DATA_A_IQ2_XXS)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
|
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||||
|
const uint ib8 = idx % 4;
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
|
||||||
|
const uint signs = pack32(u8vec4(
|
||||||
|
data_a[ib].qs[8*ib32 + 4],
|
||||||
|
data_a[ib].qs[8*ib32 + 5],
|
||||||
|
data_a[ib].qs[8*ib32 + 6],
|
||||||
|
data_a[ib].qs[8*ib32 + 7]
|
||||||
|
));
|
||||||
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
|
||||||
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
|
||||||
|
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||||
|
const uvec2 grid = iq2xxs_grid[qs];
|
||||||
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
|
#elif defined(DATA_A_IQ2_XS)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
|
const uint ib32 = (idx % 32) / 4; // 0..7
|
||||||
|
const uint ib8 = idx % 4; // 0..3
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||||
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||||
|
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
|
||||||
|
const uint sign7 = qs >> 9;
|
||||||
|
const uint sign = sign7 | (bitCount(sign7) << 7);
|
||||||
|
const uvec2 grid = iq2xs_grid[qs & 511];
|
||||||
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
|
#elif defined(DATA_A_IQ2_S)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 32; // 8 values per idx
|
||||||
|
const uint ib8 = idx % 32; // 0..31
|
||||||
|
const uint ib32 = ib8 / 4; // 0..7
|
||||||
|
|
||||||
|
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
|
||||||
|
const uint qs = data_a[ib].qs[ib8];
|
||||||
|
const uint qh = data_a[ib].qh[ib32];
|
||||||
|
const uint qhshift = 2 * (ib8 % 4);
|
||||||
|
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
|
||||||
|
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
|
||||||
|
const vec4 grid0 = vec4(unpack8(grid.x));
|
||||||
|
const vec4 grid1 = vec4(unpack8(grid.y));
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = db * FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x);
|
||||||
|
buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y);
|
||||||
|
buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z);
|
||||||
|
buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w);
|
||||||
|
buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x);
|
||||||
|
buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y);
|
||||||
|
buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z);
|
||||||
|
buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
|
||||||
|
#elif defined(DATA_A_IQ3_XXS)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 64; // 4 values per idx
|
||||||
|
const uint iqs = idx % 64; // 0..63
|
||||||
|
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const uint qs = data_a[ib].qs[iqs];
|
||||||
|
const uint signs = pack32(u8vec4(
|
||||||
|
data_a[ib].qs[is+0],
|
||||||
|
data_a[ib].qs[is+1],
|
||||||
|
data_a[ib].qs[is+2],
|
||||||
|
data_a[ib].qs[is+3]
|
||||||
|
));
|
||||||
|
const float db = d * 0.5 * (0.5 + (signs >> 28));
|
||||||
|
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
|
||||||
|
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
|
||||||
|
const uint grid = iq3xxs_grid[qs];
|
||||||
|
const vec4 v = db * vec4(unpack8(grid));
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||||
|
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||||
|
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||||
|
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||||
|
#elif defined(DATA_A_IQ3_S)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 64; // 4 values per idx
|
||||||
|
const uint iqs = idx % 64; // 0..63
|
||||||
|
const uint iqh = iqs / 8;
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const uint qs = data_a[ib].qs[iqs];
|
||||||
|
const uint qh = data_a[ib].qh[iqh];
|
||||||
|
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
|
||||||
|
const uint scale = data_a[ib].scales[iqs / 16];
|
||||||
|
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
|
||||||
|
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
|
||||||
|
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
|
||||||
|
const vec4 v = db * vec4(unpack8(grid));
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
|
||||||
|
buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
|
||||||
|
buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
|
||||||
|
buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
|
||||||
|
#elif defined(DATA_A_IQ4_XS)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A;
|
||||||
|
|
||||||
|
const uint ib = idx / 128; // 2 values per idx
|
||||||
|
const uint ib32 = (idx % 128) / 16; // 0..7
|
||||||
|
const uint iq = 16 * ib32 + 2 * (idx % 8);
|
||||||
|
|
||||||
|
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
|
||||||
|
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
|
||||||
|
const uint qshift = (idx & 8) >> 1;
|
||||||
|
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
|
||||||
|
qs = (qs >> qshift) & uint8_t(0xF);
|
||||||
|
|
||||||
|
const float d = float(data_a[ib].d);
|
||||||
|
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
|
||||||
|
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
|
||||||
|
#elif defined(DATA_A_IQ4_NL)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||||
|
|
||||||
|
const uint ib = idx / 8;
|
||||||
|
const uint iqs = idx & 0x07;
|
||||||
|
|
||||||
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
|
||||||
|
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
|
||||||
|
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
|
||||||
|
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
|
||||||
|
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
|
||||||
|
#elif defined(DATA_A_MXFP4)
|
||||||
|
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||||
|
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;
|
||||||
|
|
||||||
|
const uint ib = idx / 8;
|
||||||
|
const uint iqs = (idx & 0x07) * 2;
|
||||||
|
|
||||||
|
const float d = e8m0_to_fp32(data_a[ib].e);
|
||||||
|
const uint vui = uint(data_a[ib].qs[iqs]);
|
||||||
|
const uint vui2 = uint(data_a[ib].qs[iqs+1]);
|
||||||
|
|
||||||
|
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
|
||||||
|
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
|
||||||
|
buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
|
||||||
|
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
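Note: the Q4_K/Q5_K branches above recover each sub-block's 6-bit scale and 6-bit min from the packed 12-byte scales array using branchless index/mask/shift selection. For reference, the same unpacking is usually written on the CPU with an explicit branch; a minimal C++ sketch (the helper name is illustrative, and this assumes the standard K-quant scales layout):

#include <cstdint>

// Sketch: recover the 6-bit scale (d) and 6-bit min (m) for sub-block j (0..7)
// from the 12-byte packed scales array of a Q4_K/Q5_K block. Equivalent to the
// scidx*/mbidx* mask-and-shift selection in the shader code above.
static inline void get_scale_min(int j, const uint8_t * q, uint8_t * d, uint8_t * m) {
    if (j < 4) {
        *d = q[j] & 63;      // low 6 bits hold the scale
        *m = q[j + 4] & 63;  // low 6 bits hold the min
    } else {
        // high sub-blocks: low nibble of q[j+4], plus the spare top 2 bits of q[j-4] / q[j]
        *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        *m = (q[j + 4] >>  4) | ((q[j    ] >> 6) << 4);
    }
}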
#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint idx_k, const uint end_k) {
#if LOAD_VEC_B == 8
    // Not supported for b_type bf16 because bf16mat2x4 does not exist
    const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
    FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
    buf_b[buf_idx + 0] = bb[0].x;
    buf_b[buf_idx + 1] = bb[0].y;
    buf_b[buf_idx + 2] = bb[0].z;
    buf_b[buf_idx + 3] = bb[0].w;
    buf_b[buf_idx + 4] = bb[1].x;
    buf_b[buf_idx + 5] = bb[1].y;
    buf_b[buf_idx + 6] = bb[1].z;
    buf_b[buf_idx + 7] = bb[1].w;
#elif LOAD_VEC_B == 4
    const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
#if defined(DATA_B_BF16)
    FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
#else
    FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
#endif
    buf_b[buf_idx + 0] = bb.x;
    buf_b[buf_idx + 1] = bb.y;
    buf_b[buf_idx + 2] = bb.z;
    buf_b[buf_idx + 3] = bb.w;
#else // LOAD_VEC_B == 1
    if (idx_n < p.N && idx_k < end_k) {
        buf_b[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_b[pos_b + col * p.stride_b + row]);
    } else {
        buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
    }
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint idx_k, const uint end_k) {
#if LOAD_VEC_B == 8
    // Not supported for b_type bf16 because bf16mat2x4 does not exist
    const u16vec2 row_idx = row_ids[col];
    const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
    FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
    buf_b[buf_idx + 0] = bb[0].x;
    buf_b[buf_idx + 1] = bb[0].y;
    buf_b[buf_idx + 2] = bb[0].z;
    buf_b[buf_idx + 3] = bb[0].w;
    buf_b[buf_idx + 4] = bb[1].x;
    buf_b[buf_idx + 5] = bb[1].y;
    buf_b[buf_idx + 6] = bb[1].z;
    buf_b[buf_idx + 7] = bb[1].w;
#elif LOAD_VEC_B == 4
    const u16vec2 row_idx = row_ids[col];
    const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
    const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B;
#if defined(DATA_B_BF16)
    FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
#else
    FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
#endif
    buf_b[buf_idx + 0] = bb.x;
    buf_b[buf_idx + 1] = bb.y;
    buf_b[buf_idx + 2] = bb.z;
    buf_b[buf_idx + 3] = bb.w;
#else // LOAD_VEC_B == 1
    const uint row_i = ic * BN + col;
    if (row_i < _ne1 && idx_k < end_k) {
        const u16vec2 row_idx = row_ids[col];
        buf_b[col * SHMEM_STRIDE + row] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row]);
    } else {
        buf_b[col * SHMEM_STRIDE + row] = FLOAT_TYPE(0.0f);
    }
#endif
}
#endif
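Note: in the scalar (LOAD_VEC_B == 1) paths above, out-of-range elements are written as zeros, so the tile dot-products stay correct without per-element bounds checks in the compute loop. A CPU-side sketch of the same zero-padding guard (names are illustrative):

#include <vector>

// Sketch: copy a BK-wide row slice of B into a shared tile, zero-padding
// past the matrix edge, mirroring the idx_n/idx_k guard in load_b_to_shmem.
void load_tile_row(const std::vector<float> & B, int N, int K,
                   int n, int k0, int BK, float * tile) {
    for (int k = 0; k < BK; ++k) {
        tile[k] = (n < N && k0 + k < K) ? B[(size_t)n * K + (k0 + k)] : 0.0f;
    }
}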
@@ -13,13 +13,10 @@
 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float
-#define A_TYPE32 float
 #elif LOAD_VEC_A == 4
 #define A_TYPE vec4
-#define A_TYPE32 vec4
 #elif LOAD_VEC_A == 8
 #define A_TYPE mat2x4
-#define A_TYPE32 mat2x4
 #endif
 #endif

@@ -29,13 +26,10 @@
 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
 #define A_TYPE float16_t
-#define A_TYPE32 float
 #elif LOAD_VEC_A == 4
 #define A_TYPE f16vec4
-#define A_TYPE32 vec4
 #elif LOAD_VEC_A == 8
 #define A_TYPE f16mat2x4
-#define A_TYPE32 mat2x4
 #endif
 #endif
@@ -320,9 +320,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
     std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";

-    std::map<std::string, std::string> base_dict = {
-        {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
-    };
+    std::map<std::string, std::string> base_dict;
     std::string shader_name = "matmul";

     if (matmul_id_type == MatMulIdType::DEFAULT) {
@@ -349,7 +347,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";

-    auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
+    auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
+        switch (vec) {
+        case 1:
             if (t == "bf16") {
                 // scalar path promotes to float
                 if (!coopmat && !coopmat2) {
@@ -361,14 +361,60 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
                 return "float16_t";
             }
             return "float";
+        case 2:
+            if (t == "bf16") {
+                // scalar path promotes to float
+                if (!coopmat && !coopmat2) {
+                    return "vec2";
+                }
+                return "bf16vec2";
+            }
+            if (coopmat2 || fp16) {
+                return "f16vec2";
+            }
+            return "vec2";
+        case 4:
+            if (t == "bf16") {
+                // scalar path promotes to float
+                if (!coopmat && !coopmat2) {
+                    return "vec4";
+                }
+                return "bf16vec4";
+            }
+            if (coopmat2 || fp16) {
+                return "f16vec4";
+            }
+            return "vec4";
+        case 8:
+            if (t == "bf16") {
+                // scalar path promotes to float
+                if (!coopmat && !coopmat2) {
+                    return "mat2x4";
+                }
+                throw std::runtime_error("bf16 vec8 not supported");
+            }
+            if (coopmat2 || fp16) {
+                return "f16mat2x4";
+            }
+            return "mat2x4";
+        default:
+            throw std::runtime_error("invalid vector size");
+        }
     };

+    const std::map<std::string, std::string> float_type_dict_f16 = {
+        {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
+        {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
+        {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
+        {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
+    };
+
     // Shaders with f16 B_TYPE
-    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

-    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-    string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+    string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

     // bf16
     {
@@ -379,13 +425,19 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
         // scalar path promotes to float
         std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";

+        const std::map<std::string, std::string> float_type_dict_bf16 = {
+            {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
+            {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
+            {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
+        };
+
         // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
 #if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
         if (!(coopmat || coopmat2))
 #endif
         {
-            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }
     }
@@ -406,20 +458,27 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;

+        const std::map<std::string, std::string> float_type_dict = {
+            {"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
+            {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
+            {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
+            {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
+        };
+
         // don't generate f32 variants for coopmat2
         if (!coopmat2) {
-            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }

         if (tname != "f16" && tname != "f32") {
-            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
-            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
         }

 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
-            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
         }
 #endif
     }
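Note: the generator now builds each shader's macro table by layering maps — base_dict, then a per-type float_type_dict, then the per-shader entries — via nested merge_maps calls. A minimal sketch of a helper with that shape (an assumption: the actual merge_maps is defined elsewhere in vulkan-shaders-gen.cpp, and its collision semantics should be checked there):

#include <map>
#include <string>

static std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string> & a,
                                                     const std::map<std::string, std::string> & b) {
    std::map<std::string, std::string> out = a;
    // std::map::insert keeps existing keys, so entries already in `a` win on collision
    out.insert(b.begin(), b.end());
    return out;
}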
@@ -76,6 +76,7 @@ struct ggml_backend_zdnn_context {
 struct ggml_backend_zdnn_buffer {
     void * data;
+    ggml_backend_zdnn_buffer * extra; // for bias, etc.
     size_t size;

     zdnn_tensor_desc pre_tfm_desc;
@@ -115,9 +115,7 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten
     ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra;
     ggml_backend_zdnn_buffer * inputs_extra = (ggml_backend_zdnn_buffer *)inputs->extra;
     ggml_backend_zdnn_buffer * output_extra = (ggml_backend_zdnn_buffer *)output->extra;
+    ggml_backend_zdnn_buffer * bias_extra = (ggml_backend_zdnn_buffer *)output_extra->extra;
-
-    zdnn_tensor_desc ptd_bias, td_bias;
-    zdnn_ztensor zt_bias;

     const int64_t weights_rows = ne01;
     const int64_t weights_cols = ne00;
@@ -129,14 +127,6 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten
     const int64_t output_rows = ne1;
     const int64_t output_cols = ne0;

-    const int64_t bias_dim [GGML_MAX_DIMS] = { 1, 1, 1, output_cols };
-    ggml_zdnn_create_tensor(ptd_bias, td_bias, zt_bias, output, bias_dim, ZDNN_1D);
-
-    void * bias_data = (void *)calloc(ne0, ggml_element_size(output));
-    if (weights_extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(weights_extra->ztensor, weights->data);
-    if (inputs_extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(inputs_extra->ztensor, inputs->data);
-    ggml_zdnn_load_tensor(zt_bias, bias_data);
-
     // GGML_LOG_INFO("%s: tensor '%s' tensor dimensions: [%ld, %ld, %ld, %ld] pre_tfm_desc dimensions: [%ld, %ld, %ld, %ld]\n",
     //               __func__, weights_extra->name,
     //               weights->ne[3], weights->ne[2], weights->ne[1], weights->ne[0],
@@ -158,29 +148,21 @@ static void ggml_zdnn_mul_mat_op(ggml_backend_zdnn_context * ctx, const ggml_ten
     GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1 == inputs->ne[0] && "inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]");
     GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2 == inputs->ne[1] && "inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]");

-    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &zt_bias,
+    ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor,
                                         false, true, MATMUL_OP_ADDITION, &output_extra->ztensor));
     // TODO: Remove in the future as we are currently DLF16 -> FP32 then in the next op, FP32 -> DLF16 again. Inefficient.
     ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data));

-    ZDNN_CHECK(zdnn_free_ztensor_buffer(&zt_bias));
-    free(bias_data);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(weights_rows);
+    GGML_UNUSED(weights_cols);
+    GGML_UNUSED(inputs_rows);
+    GGML_UNUSED(inputs_cols);
+    GGML_UNUSED(output_rows);
+    GGML_UNUSED(output_cols);
 }

 static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    bool use_mul_mat_vec =
-        (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
-
-    bool use_mul_mat_vec_q =
-        ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-    bool use_mul_mat_q =
-        ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
     // debug helpers
     // GGML_LOG_INFO("%s: use_mul_mat_vec = %d\n", __func__, use_mul_mat_vec);
     // GGML_LOG_INFO("%s: use_mul_mat_vec_q = %d\n", __func__, use_mul_mat_vec_q);
@@ -192,25 +174,7 @@ static void ggml_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const gg
     // GGML_LOG_INFO("%s: src0 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     // GGML_LOG_INFO("%s: src1 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

-    if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16
-        && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
-        && src1->ne[2] * src1->ne[3] > 1) {
-        // general KQ + KQV multi-batch
-        GGML_LOG_INFO("%s: using zdnn_mul_mat_batched for KQ + KQV multi-batch\n", __func__);
-        // ggml_zdnn_mul_mat_batched(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec for vector multiplication\n", __func__);
-        // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec, nullptr);
-    } else if (use_mul_mat_vec_q) {
-        GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec_q for quantized vector multiplication\n", __func__);
-        // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec_q, ggml_zdnn_quantize_row_q8_1);
-    } else if (use_mul_mat_q) {
-        GGML_LOG_INFO("%s: using zdnn_op_mul_mat_q for quantized matrix multiplication\n", __func__);
-        // ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_q, ggml_zdnn_quantize_mmq_q8_1);
-    } else {
-        // GGML_LOG_INFO("%s: using zdnn_op_mul_mat for general matrix multiplication\n", __func__);
-        ggml_zdnn_mul_mat_op(ctx, src0, src1, dst);
-    }
+    ggml_zdnn_mul_mat_op(ctx, src0, src1, dst);
 }

 static bool ggml_zdnn_compute_forward(ggml_backend_zdnn_context * ctx, ggml_tensor * dst) {
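Note: zdnn_matmul_transpose_op fuses a bias addition into the matmul, so the backend expresses a plain GGML MUL_MAT as out = inputs x weights^T + bias with an all-zero bias; the change above replaces the per-call calloc'd zt_bias with a preallocated bias hanging off output_extra->extra. Conceptually (a plain C++ sketch of the fused operation, not the zdnn API):

#include <vector>

// Sketch: out[r][c] = sum_i in[r][i] * w[c][i] + bias[c]. With bias all zeros
// this reduces to a plain matmul -- the shape of the fused zdnn call above.
void matmul_transpose_add(const std::vector<float> & in,  int rows, int k,
                          const std::vector<float> & w,   int cols,
                          const std::vector<float> & bias,
                          std::vector<float> & out) {
    out.assign((size_t)rows * cols, 0.0f);
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            float acc = bias[c]; // zero bias => no-op addition
            for (int i = 0; i < k; ++i) {
                acc += in[(size_t)r * k + i] * w[(size_t)c * k + i];
            }
            out[(size_t)r * cols + c] = acc;
        }
    }
}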
@@ -253,6 +217,8 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
     }

     return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(ctx_dev);
 }

 static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_dev, const ggml_tensor * op) {
@@ -266,22 +232,30 @@ static bool ggml_zdnn_supports_op(const ggml_backend_zdnn_device_context * ctx_d
         case GGML_OP_MUL_MAT:
             {
-                const ggml_tensor * src0 = op->src[0];
-                const ggml_tensor * src1 = op->src[1];
+                const ggml_tensor * weights = op->src[0];
+                const ggml_tensor * inputs = op->src[1];

-                const int64_t ne10 = src1->ne[0];
+                const int64_t ne10 = inputs->ne[0];
                 const int64_t ne0 = op->ne[0];
                 const int64_t ne1 = op->ne[1];

                 const int64_t max_batch = ctx_dev->max_size;

-                return ggml_is_matrix(src0) &&
-                       ggml_is_matrix(src1) &&
-                       ggml_is_contiguous(src0) &&
-                       ggml_is_contiguous(src1) &&
-                       src0->view_src == nullptr && src1->view_src == nullptr &&
-                       src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
-                       (ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch);
+                if (!ggml_is_matrix(weights) || !ggml_is_matrix(inputs) ||
+                    !ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
+                    weights->view_src != nullptr || inputs->view_src != nullptr ||
+                    ne0 > max_batch || ne1 > max_batch || ne10 > max_batch) {
+                    return false;
+                }
+
+                switch (weights->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
+                        return true;
+                    default:
+                        return false;
+                }
             } break;

         default:
@@ -374,10 +348,12 @@ static void ggml_zdnn_free(ggml_backend_zdnn_context * ctx) {
 static void ggml_backend_zdnn_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_zdnn_buffer_context * ctx = (ggml_backend_zdnn_buffer_context *)buffer->context;

-    for (int i = 0; i < ctx->n_buffers; i++) {
-        if (ctx->buffers[i]->ztensor.buffer != NULL && ctx->buffers[i]->ztensor.is_transformed) {
-            ZDNN_CHECK(zdnn_free_ztensor_buffer(&ctx->buffers[i]->ztensor));
-        }
+    for (const auto & buf_ptr : ctx->buffers) {
+        ggml_backend_zdnn_buffer * buf = buf_ptr.get();
+
+        // Free any extra buffer allocated for the tensor. E.g., bias for GGML_OP_MUL_MAT
+        if (buf->extra != nullptr) free(buf->extra->data);
+        if (buf->ztensor.buffer_size > 0) ZDNN_CHECK(zdnn_free_ztensor_buffer(&buf->ztensor));
     }

     delete ctx;
@@ -402,11 +378,37 @@ static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer
     std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>();
     zdnn_buffer->data = tensor->data;
     zdnn_buffer->size = tsize;
-    strncpy(zdnn_buffer->name, tensor->name, GGML_MAX_NAME - 1);
+    zdnn_buffer->extra = nullptr;
+    snprintf(zdnn_buffer->name, GGML_MAX_NAME, "%s", tensor->name);

     ggml_zdnn_init_tensor(zdnn_buffer.get(), tensor);
     tensor->extra = zdnn_buffer.get();

+    switch (tensor->op) {
+        case GGML_OP_MUL_MAT:
+            {
+                std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_bias_buffer = std::make_unique<ggml_backend_zdnn_buffer>();
+                zdnn_bias_buffer->data = (void *)calloc(tensor->ne[0], ggml_element_size(tensor));
+                zdnn_bias_buffer->size = ggml_element_size(tensor) * tensor->ne[0];
+                snprintf(zdnn_bias_buffer->name, GGML_MAX_NAME, "%.*s (bias)",
+                         GGML_MAX_NAME - (int)sizeof(" (bias)"), tensor->name);
+
+                const int64_t bias_dim[GGML_MAX_DIMS] = { 1, 1, 1, tensor->ne[0] };
+                ggml_zdnn_create_tensor(zdnn_bias_buffer->pre_tfm_desc,
+                                        zdnn_bias_buffer->tfm_desc,
+                                        zdnn_bias_buffer->ztensor,
+                                        tensor, bias_dim, ZDNN_1D);
+
+                ggml_zdnn_load_tensor(zdnn_bias_buffer->ztensor, zdnn_bias_buffer->data);
+                zdnn_buffer->extra = zdnn_bias_buffer.get();
+
+                ctx->buffers.push_back(std::move(zdnn_bias_buffer));
+                ctx->n_buffers++;
+            } break;
+        default:
+            break;
+    }
+
     ctx->buffers.push_back(std::move(zdnn_buffer));
     ctx->n_buffers++;
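Note: the snprintf with a "%.*s" precision above replaces strncpy for two reasons: it always NUL-terminates, and it truncates the tensor name just enough that the " (bias)" suffix still fits inside GGML_MAX_NAME. A standalone illustration (MAX_NAME shrunk to make the truncation visible):

#include <cstdio>

int main() {
    enum { MAX_NAME = 16 }; // stand-in for GGML_MAX_NAME
    const char * tensor_name = "blk.0.ffn_up.weight"; // longer than MAX_NAME
    char buf[MAX_NAME];
    // limit the copied prefix so " (bias)" (incl. its NUL) always fits
    snprintf(buf, MAX_NAME, "%.*s (bias)", MAX_NAME - (int)sizeof(" (bias)"), tensor_name);
    printf("%s\n", buf); // prints: blk.0.ff (bias)
    return 0;
}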
@ -414,6 +416,8 @@ static enum ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer
|
||||||
// __func__, tensor->name, buffer_idx, tsize);
|
// __func__, tensor->name, buffer_idx, tsize);
|
||||||
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
|
|
||||||
|
GGML_UNUSED(buffer_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||||
|
|
@ -425,6 +429,13 @@ static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer,
|
||||||
static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||||
memcpy((char *)tensor->data + offset, data, size);
|
memcpy((char *)tensor->data + offset, data, size);
|
||||||
|
|
||||||
|
ggml_backend_zdnn_buffer * extra = (ggml_backend_zdnn_buffer *)tensor->extra;
|
||||||
|
|
||||||
|
// Fixes the LLAMA_SET_ROWS bug
|
||||||
|
// see: https://github.com/ggml-org/llama.cpp/issues/15414
|
||||||
|
if (tensor->buffer->usage == GGML_BACKEND_BUFFER_USAGE_COMPUTE && extra->ztensor.is_transformed) zdnn_reset_ztensor(&extra->ztensor);
|
||||||
|
if (extra->ztensor.is_transformed == false) ggml_zdnn_load_tensor(extra->ztensor, tensor->data);
|
||||||
|
|
||||||
GGML_UNUSED(buffer);
|
GGML_UNUSED(buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -528,29 +539,6 @@ ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_type(void) {
|
||||||
return &ggml_backend_buffer_type_zdnn;
|
return &ggml_backend_buffer_type_zdnn;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char * ggml_backend_zdnn_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
|
|
||||||
return GGML_ZDNN_NAME "_Mapped";
|
|
||||||
|
|
||||||
GGML_UNUSED(buft);
|
|
||||||
}
|
|
||||||
|
|
||||||
static ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_from_ptr_type(void) {
|
|
||||||
static ggml_backend_buffer_type ggml_backend_buffer_from_ptr_type_zdnn = {
|
|
||||||
/* .iface = */ {
|
|
||||||
/* .get_name = */ ggml_backend_zdnn_buffer_from_ptr_type_get_name,
|
|
||||||
/* .alloc_buffer = */ ggml_backend_zdnn_buffer_type_alloc_buffer,
|
|
||||||
/* .get_alignment = */ ggml_backend_zdnn_buffer_type_get_alignment,
|
|
||||||
/* .get_max_size = */ NULL,
|
|
||||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
|
||||||
/* .is_host = */ ggml_backend_zdnn_buffer_type_is_host,
|
|
||||||
},
|
|
||||||
/* .device = */ &g_ggml_backend_zdnn_device,
|
|
||||||
/* .context = */ NULL,
|
|
||||||
};
|
|
||||||
|
|
||||||
return &ggml_backend_buffer_from_ptr_type_zdnn;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// backend
|
// backend
|
||||||
//
|
//
|
||||||
|
|
@ -594,27 +582,6 @@ static ggml_guid_t ggml_backend_zdnn_guid(void) {
|
||||||
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
|
return reinterpret_cast<ggml_guid_t>((void *)guid_str);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: remove in the future
|
|
||||||
ggml_backend_t ggml_backend_zdnn_init(void) {
|
|
||||||
ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_zdnn_reg(), 0);
|
|
||||||
|
|
||||||
ggml_backend_zdnn_context * ctx = ggml_zdnn_init(dev);
|
|
||||||
if (ctx == NULL) {
|
|
||||||
GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_t backend = (ggml_backend_t)malloc(sizeof(ggml_backend));
|
|
||||||
*backend = (ggml_backend) {
|
|
||||||
/* .guid = */ ggml_backend_zdnn_guid(),
|
|
||||||
/* .iface = */ ggml_backend_zdnn_i,
|
|
||||||
/* .device = */ dev,
|
|
||||||
/* .context = */ ctx,
|
|
||||||
};
|
|
||||||
|
|
||||||
return backend;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool ggml_backend_is_zdnn(ggml_backend_t backend) {
|
bool ggml_backend_is_zdnn(ggml_backend_t backend) {
|
||||||
return backend != NULL &&
|
return backend != NULL &&
|
||||||
ggml_guid_matches(backend->guid, ggml_backend_zdnn_guid());
|
ggml_guid_matches(backend->guid, ggml_backend_zdnn_guid());
|
||||||
|
|
@ -634,11 +601,15 @@ static const char * ggml_backend_zdnn_device_get_name(ggml_backend_dev_t dev) {
|
||||||
|
|
||||||
static const char * ggml_backend_zdnn_device_get_description(ggml_backend_dev_t dev) {
|
static const char * ggml_backend_zdnn_device_get_description(ggml_backend_dev_t dev) {
|
||||||
return "IBM Z Neural Network Processing Assist (NNPA)";
|
return "IBM Z Neural Network Processing Assist (NNPA)";
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_zdnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
static void ggml_backend_zdnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||||
*free = 0;
|
*free = 0;
|
||||||
*total = 0;
|
*total = 0;
|
||||||
|
|
||||||
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_backend_dev_type ggml_backend_zdnn_device_get_type(ggml_backend_dev_t dev) {
|
static enum ggml_backend_dev_type ggml_backend_zdnn_device_get_type(ggml_backend_dev_t dev) {
|
||||||
|
|
@ -655,8 +626,8 @@ static void ggml_backend_zdnn_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||||
props->caps = (ggml_backend_dev_caps) {
|
props->caps = (ggml_backend_dev_caps) {
|
||||||
/* .async = */ false,
|
/* .async = */ false,
|
||||||
/* .host_buffer = */ false,
|
/* .host_buffer = */ false,
|
||||||
/* .buffer_from_host_ptr = */ true,
|
/* .buffer_from_host_ptr = */ false,
|
||||||
/* .events = */ false,
|
/* .events = */ false
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -672,7 +643,7 @@ static ggml_backend_t ggml_backend_zdnn_device_init(ggml_backend_dev_t dev, cons
|
||||||
/* .guid = */ ggml_backend_zdnn_guid(),
|
/* .guid = */ ggml_backend_zdnn_guid(),
|
||||||
/* .iface = */ ggml_backend_zdnn_i,
|
/* .iface = */ ggml_backend_zdnn_i,
|
||||||
/* .device = */ dev,
|
/* .device = */ dev,
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx
|
||||||
};
|
};
|
||||||
|
|
||||||
return backend;
|
return backend;
|
||||||
|
|
@ -686,46 +657,6 @@ static ggml_backend_buffer_type_t ggml_backend_zdnn_device_get_buffer_type(ggml_
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_t ggml_backend_zdnn_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
|
||||||
ggml_backend_zdnn_buffer_context * ctx = new ggml_backend_zdnn_buffer_context();
|
|
||||||
|
|
||||||
ctx->all_data = ptr;
|
|
||||||
ctx->all_size = size;
|
|
||||||
ctx->owned = false;
|
|
||||||
ctx->n_buffers = 0;
|
|
||||||
|
|
||||||
const size_t size_page = sysconf(_SC_PAGESIZE);
|
|
||||||
|
|
||||||
// page-align the data ptr
|
|
||||||
{
|
|
||||||
const uintptr_t offs = (uintptr_t) ptr % size_page;
|
|
||||||
ptr = (void *)((char *)ptr - offs);
|
|
||||||
size += offs;
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t size_aligned = size;
|
|
||||||
if ((size_aligned % size_page) != 0) {
|
|
||||||
size_aligned += size_page - (size_aligned % size_page);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *)dev->context;
|
|
||||||
|
|
||||||
GGML_ASSERT(ctx_dev->zdnn_device >= 0);
|
|
||||||
int device = ctx_dev->zdnn_device; GGML_UNUSED(device);
|
-    std::unique_ptr<ggml_backend_zdnn_buffer> zdnn_buffer = std::make_unique<ggml_backend_zdnn_buffer>();
-    zdnn_buffer->data = ptr;
-    zdnn_buffer->size = size;
-    ctx->buffers.push_back(std::move(zdnn_buffer));
-
-    GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB\n",
-                  __func__, size_aligned / 1024.0 / 1024.0);
-
-    ++ctx->n_buffers;
-
-    return ggml_backend_buffer_init(ggml_backend_zdnn_buffer_from_ptr_type(), ggml_backend_zdnn_buffer_i, ctx, size);
-}
-
 static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     ggml_backend_zdnn_device_context * ctx_dev = (ggml_backend_zdnn_device_context *) dev->context;

@@ -734,8 +665,7 @@ static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const g

 static bool ggml_backend_zdnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     return
-        buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name ||
-        buft->iface.get_name == ggml_backend_zdnn_buffer_from_ptr_type_get_name;
+        buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name;

     GGML_UNUSED(dev);
 }

@@ -749,7 +679,7 @@ static ggml_backend_device_i ggml_backend_zdnn_device_i = {
     /* .init_backend         = */ ggml_backend_zdnn_device_init,
     /* .get_buffer_type      = */ ggml_backend_zdnn_device_get_buffer_type,
     /* .get_host_buffer_type = */ NULL,
-    /* .buffer_from_host_ptr = */ ggml_backend_zdnn_device_buffer_from_ptr,
+    /* .buffer_from_host_ptr = */ NULL,
     /* .supports_op          = */ ggml_backend_zdnn_device_supports_op,
     /* .supports_buft        = */ ggml_backend_zdnn_device_supports_buft,
     /* .offload_op           = */ NULL,

@@ -813,7 +743,7 @@ static ggml_backend_reg_i ggml_backend_zdnn_reg_i = {
     /* .get_name         = */ ggml_backend_zdnn_reg_get_name,
     /* .get_device_count = */ ggml_backend_zdnn_reg_device_count,
     /* .get_device       = */ ggml_backend_zdnn_reg_device_get,
-    /* .get_proc_address = */ ggml_backend_zdnn_get_proc_address,
+    /* .get_proc_address = */ ggml_backend_zdnn_get_proc_address
 };

 static void ggml_zdnn_cleanup(void) {

@@ -831,13 +761,13 @@ ggml_backend_reg_t ggml_backend_zdnn_reg(void) {
     g_ggml_backend_zdnn_reg = (ggml_backend_reg) {
         /* .api_version = */ GGML_ZDNN_VERSION,
         /* .iface       = */ ggml_backend_zdnn_reg_i,
-        /* .context     = */ NULL,
+        /* .context     = */ NULL
     };

     g_ggml_backend_zdnn_device = (ggml_backend_device) {
         /* .iface   = */ ggml_backend_zdnn_device_i,
         /* .reg     = */ &g_ggml_backend_zdnn_reg,
-        /* .context = */ &g_ggml_ctx_dev_main,
+        /* .context = */ &g_ggml_ctx_dev_main
     };

     return &g_ggml_backend_zdnn_reg;
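The `supports_buft` check above relies on function-pointer identity: each buffer type is a singleton whose `iface.get_name` pointer is unique, so comparing pointers is a cheap way to ask "is this my buffer type?" without string comparisons. With the from-ptr buffer type removed, only the default zDNN buffer type matches. A minimal sketch of the pattern, with illustrative names that are not from this commit:

// Hypothetical backend: recognize our own buffer type by pointer identity.
static const char * my_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "MYBACKEND";
    GGML_UNUSED(buft);
}

static bool my_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    // true only for the buffer type whose vtable carries our get_name function
    return buft->iface.get_name == my_buffer_type_get_name;
    GGML_UNUSED(dev);
}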
@@ -111,6 +111,7 @@ class Keys:
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
         DECODER_BLOCK_COUNT               = "{arch}.decoder_block_count"
         ATTN_LOGIT_SOFTCAPPING            = "{arch}.attn_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING          = "{arch}.router_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING           = "{arch}.final_logit_softcapping"
         SWIN_NORM                         = "{arch}.swin_norm"
         RESCALE_EVERY_N_LAYERS            = "{arch}.rescale_every_n_layers"

@@ -146,6 +147,8 @@ class Keys:
         REL_BUCKETS_COUNT        = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW           = "{arch}.attention.sliding_window"
         SCALE                    = "{arch}.attention.scale"
+        OUTPUT_SCALE             = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH       = "{arch}.attention.temperature_length"
         KEY_LENGTH_MLA           = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA         = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS         = "{arch}.attention.shared_kv_layers"

@@ -161,6 +164,10 @@ class Keys:
         SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
         SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"

     class Split:
         LLM_KV_SPLIT_NO = "split.no"

@@ -1114,6 +1121,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.GPTNEOX: [
@@ -733,6 +733,9 @@ class GGUFWriter:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

@@ -829,6 +832,12 @@ class GGUFWriter:
     def add_attention_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)

+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

@@ -859,6 +868,18 @@ class GGUFWriter:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
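These writer methods emit architecture-scoped keys, e.g. `grok.rope.scaling.yarn_beta_fast` for the Grok arch. A rough consumer-side sketch using ggml's GGUF C API follows; the file name and the `grok.` key prefix are assumptions for illustration, and the key is treated as optional just as the loader does:

#include "gguf.h"

// Sketch: read one of the new YARN keys back out of a GGUF file.
struct gguf_init_params params = { /* .no_alloc = */ true, /* .ctx = */ nullptr };
struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
if (gctx) {
    const int64_t kid = gguf_find_key(gctx, "grok.rope.scaling.yarn_beta_fast");
    if (kid >= 0) {
        const float beta_fast = gguf_get_val_f32(gctx, kid); // key is optional: read only if present
        (void) beta_fast;
    }
    gguf_free(gctx);
}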
@@ -136,6 +136,7 @@ class TensorNameMap:
         "model.layers.{bid}.norm",                        # mamba-qbert
         "backbone.layers.{bid}.norm",                     # mamba
         "transformer.decoder_layer.{bid}.rms_norm",       # Grok
+        "model.layers.{bid}.pre_attn_norm",               # grok-2
         "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
         "encoder.layers.{bid}.input_layernorm",           # chatglm
         "transformer.layers.{bid}.attn_norm",             # openelm

@@ -278,6 +279,7 @@ class TensorNameMap:
         "transformer.layer.{bid}.sa_layer_norm",          # distillbert
         "encoder.layers.{bid}.norm1",                     # nomic-bert
         "transformer.decoder_layer.{bid}.rms_norm_1",     # Grok
+        "model.layers.{bid}.post_attn_norm",              # grok-2
         "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
     ),

@@ -313,6 +315,7 @@ class TensorNameMap:
         "h.{bid}.ln_2",                                   # gpt2
         "model.layers.{bid}.ffn_norm",                    # internlm2
         "transformer.decoder_layer.{bid}.rms_norm_2",     # Grok
+        "model.layers.{bid}.pre_moe_norm",                # grok-2
         "encoder.layers.{bid}.post_attention_layernorm",  # chatglm
         "transformer.layers.{bid}.ffn_norm",              # openelm
         "model.layers.{bid}.pre_ff_layernorm",            # jamba granite-hybrid

@@ -338,6 +341,7 @@ class TensorNameMap:
         "model.layers.{bid}.post_mlp_layernorm",          # glm-4-0414
         "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
         "model.layers.{bid}.feed_forward.up_proj",
+        "model.layers.{bid}.post_moe_norm",               # grok-2
     ),

     MODEL_TENSOR.FFN_GATE_INP: (
@@ -139,6 +139,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"   },
     { LLM_KV_DECODER_BLOCK_COUNT,               "%s.decoder_block_count"      },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,            "%s.attn_logit_softcapping"   },
+    { LLM_KV_ROUTER_LOGIT_SOFTCAPPING,          "%s.router_logit_softcapping" },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING,           "%s.final_logit_softcapping"  },
     { LLM_KV_SWIN_NORM,                         "%s.swin_norm"                },
     { LLM_KV_RESCALE_EVERY_N_LAYERS,            "%s.rescale_every_n_layers"   },

@@ -169,6 +170,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,  "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,          "%s.attention.sliding_window"         },
     { LLM_KV_ATTENTION_SCALE,                   "%s.attention.scale"                  },
+    { LLM_KV_ATTENTION_OUTPUT_SCALE,            "%s.attention.output_scale"           },
+    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,      "%s.attention.temperature_length"     },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,          "%s.attention.key_length_mla"         },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,        "%s.attention.value_length_mla"       },

@@ -182,6 +185,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,         "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED,            "%s.rope.scaling.finetuned"               },
     { LLM_KV_ROPE_SCALING_YARN_LOG_MUL,         "%s.rope.scaling.yarn_log_multiplier"     },
+    { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,      "%s.rope.scaling.yarn_ext_factor"         },
+    { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,     "%s.rope.scaling.yarn_attn_factor"        },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_FAST,       "%s.rope.scaling.yarn_beta_fast"          },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,       "%s.rope.scaling.yarn_beta_slow"          },

     { LLM_KV_SPLIT_NO,                          "split.no"    },
     { LLM_KV_SPLIT_COUNT,                       "split.count" },

@@ -398,12 +405,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_ATTN_ROT_EMBD,  "blk.%d.attn_rot_embd" },
     { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp"  },
     { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm"      },
+    { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate"      },
+    { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down"      },
+    { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up"        },
     { LLM_TENSOR_FFN_GATE_EXP,   "blk.%d.ffn_gate.%d"   },
     { LLM_TENSOR_FFN_DOWN_EXP,   "blk.%d.ffn_down.%d"   },
     { LLM_TENSOR_FFN_UP_EXP,     "blk.%d.ffn_up.%d"     },
     { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
     { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
     { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps"   },
+    { LLM_TENSOR_FFN_POST_NORM,  "blk.%d.post_ffw_norm" },
     { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
     { LLM_TENSOR_ATTN_OUT_NORM,  "blk.%d.attn_output_norm"  },
 },
@@ -143,6 +143,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_DECODER_BLOCK_COUNT,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,

@@ -173,6 +174,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,

@@ -186,6 +189,10 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,

     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,
@@ -70,6 +70,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "hunyuan-dense",     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2       },
     { "seed_oss",          LLM_CHAT_TEMPLATE_SEED_OSS      },
+    { "grok-2",            LLM_CHAT_TEMPLATE_GROK_2        },
 };

 llm_chat_template llm_chat_template_from_str(const std::string & name) {

@@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
     } else if (tmpl_contains("<seed:bos>")) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
+    } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
+        return LLM_CHAT_TEMPLATE_GROK_2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }

@@ -763,6 +766,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<seed:bos>assistant\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "System: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "user") {
+                ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << "<|separator|>\n\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
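For reference, a minimal sketch of driving the new template through the public API; the output noted in the comment follows directly from the branch above, while the buffer sizing and error handling are simplified assumptions:

#include "llama.h"
#include <vector>
#include <cstdio>

int main() {
    const llama_chat_message chat[] = {
        { "system", "You are Grok." },
        { "user",   "Hello!"        },
    };
    std::vector<char> buf(1024);
    // "grok-2" resolves via LLM_CHAT_TEMPLATES; add_ass appends the "Assistant:" cue
    const int32_t n = llama_chat_apply_template("grok-2", chat, 2, /*add_ass=*/true,
                                                buf.data(), (int32_t) buf.size());
    if (n > 0 && n <= (int32_t) buf.size()) {
        // Expected shape:
        // System: You are Grok.<|separator|>\n\nHuman: Hello!<|separator|>\n\nAssistant:
        printf("%.*s\n", n, buf.data());
    }
    return 0;
}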
@@ -50,6 +50,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_SEED_OSS,
+    LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
@@ -35,10 +35,10 @@ llama_context::llama_context(

     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.no_perf          = params.no_perf;

@@ -181,7 +181,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
+        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }

@@ -2263,9 +2263,9 @@ llama_context_params llama_context_default_params() {
         /*.rope_freq_base   =*/ 0.0f,
         /*.rope_freq_scale  =*/ 0.0f,
         /*.yarn_ext_factor  =*/ -1.0f,
-        /*.yarn_attn_factor =*/ 1.0f,
-        /*.yarn_beta_fast   =*/ 32.0f,
-        /*.yarn_beta_slow   =*/ 1.0f,
+        /*.yarn_attn_factor =*/ -1.0f,
+        /*.yarn_beta_fast   =*/ -1.0f,
+        /*.yarn_beta_slow   =*/ -1.0f,
         /*.yarn_orig_ctx    =*/ 0,
         /*.defrag_thold     =*/ -1.0f,
         /*.cb_eval          =*/ nullptr,
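The pattern here is a sentinel default: every YARN field now defaults to -1.0f in llama_context_default_params(), which means "defer to the model's hparams" (themselves loaded from GGUF metadata or an architecture default); any non-negative value set by the caller wins. A small illustration of the resolution order, with hypothetical values:

// Resolution order for a YARN parameter, as in the constructor above:
//   1. caller-supplied context param, if >= 0
//   2. otherwise the model hparam (e.g. yarn_beta_fast defaults to 8.0f for Grok)
float resolve_yarn(float param, float hparam) {
    return param >= 0.0f ? param : hparam;
}
// resolve_yarn(-1.0f, 8.0f) -> 8.0f  (sentinel: the model decides)
// resolve_yarn(32.0f, 8.0f) -> 32.0f (explicit user override)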
@@ -1335,14 +1335,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(

     if (arch == LLM_ARCH_GROK) {
         // need to do the following:
-        // multiply by attn_output_multiplyer of 0.08838834764831845
+        // multiply by attn_output_multiplier
         // and then :
         // kq = 30 * tanh(kq / 30)
         // before the softmax below

-        kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+        kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
         cb(kq, "kq_tanh", il);
-        kq = ggml_scale(ctx0, kq, 30);
+        kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
         cb(kq, "kq_scaled", il);
     }
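What this computes is attention-logit soft-capping: the scores are scaled, squashed through tanh, and rescaled so their magnitude can never exceed the cap. With output scale s = f_attn_out_scale and cap c = f_attn_logit_softcapping:

\mathrm{kq}' = c \cdot \tanh\!\left(\frac{s \cdot \mathrm{kq}}{c}\right), \qquad \mathrm{kq}' \in (-c,\, c)

The grok-1 constants (s = 0.08838834764831845, numerically equal to 1/sqrt(128), and c = 30) become hparam defaults, so grok-2 GGUFs can override both via metadata.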
@@ -83,6 +83,7 @@ struct llama_hparams {
     float f_norm_group_eps;

     float f_attn_logit_softcapping   = 50.0f;
+    float f_router_logit_softcapping = 30.0f;
     float f_final_logit_softcapping  = 30.0f;

     // for RWKV

@@ -104,6 +105,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul = 0.0f;

+    float yarn_ext_factor  = -1.0f;
+    float yarn_attn_factor =  1.0f;
+    float yarn_beta_fast   = 32.0f;
+    float yarn_beta_slow   =  1.0f;
+
     std::array<int, 4> rope_sections;

     // Sliding Window Attention (SWA)

@@ -136,6 +142,10 @@ struct llama_hparams {
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;

+    // grok-2
+    float    f_attn_out_scale = 0.0f;
+    uint32_t attn_temp_length = 0;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
@@ -685,7 +685,30 @@ void llama_model::load_hparams(llama_model_loader & ml) {
            } break;
        case LLM_ARCH_GROK:
            {
+               // defaults for old GGUFs
+               hparams.yarn_beta_fast             = 8.0f;
+               hparams.f_logit_scale              = 0.5773502691896257f;
+               hparams.f_embedding_scale          = 78.38367176906169f;
+               hparams.f_attn_out_scale           = 0.08838834764831845f;
+               hparams.f_attn_logit_softcapping   = 30.0f;
+               hparams.f_router_logit_softcapping = 30.0f;
+               // no final_logit_softcapping in grok-1
+               hparams.f_final_logit_softcapping  = 0.0f;
+
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+               ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp, false);
+               ml.get_key(LLM_KV_LOGIT_SCALE,                 hparams.f_logit_scale, false);
+               ml.get_key(LLM_KV_EMBEDDING_SCALE,             hparams.f_embedding_scale, false);
+               ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE,      hparams.f_attn_out_scale, false);
+               ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING,      hparams.f_attn_logit_softcapping, false);
+               ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING,    hparams.f_router_logit_softcapping, false);
+               ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING,     hparams.f_final_logit_softcapping, false);
+
+               ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH,  hparams.attn_temp_length, false);
+               ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,  hparams.yarn_ext_factor, false);
+               ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
+               ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST,   hparams.yarn_beta_fast, false);
+               ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,   hparams.yarn_beta_slow, false);

                switch (hparams.n_layer) {
                    case 64: type = LLM_TYPE_314B; break;
@@ -2540,6 +2563,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                    }

+                   const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
                    for (int i = 0; i < n_layer; ++i) {
                        auto & layer = layers[i];

@@ -2554,12 +2578,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

-                       layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                       layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                       layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                       layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert}, 0);
-
-                       layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                       layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+                       layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, TENSOR_NOT_REQUIRED);
+                       layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+
+                       layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                       layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                       layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                       layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                       layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                       if (!layer.ffn_post_norm) {
+                           layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                       }
                    }
                } break;
        case LLM_ARCH_DBRX:
@@ -7028,9 +7059,6 @@ struct llm_build_grok : public llm_graph_context {

        inpL = build_inp_embd(model.tok_embd);

-       // multiply by embedding_multiplier_scale of 78.38367176906169
-       inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
        // inp_pos - contains the positions
        ggml_tensor * inp_pos = build_inp_pos();

@@ -7102,26 +7130,22 @@ struct llm_build_grok : public llm_graph_context {
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

-           // Grok
-           // if attn_out_norm is present then apply it before adding the input
-           if (model.layers[il].attn_out_norm) {
-               cur = build_norm(cur,
-                       model.layers[il].attn_out_norm, NULL,
-                       LLM_NORM_RMS, il);
-               cb(cur, "attn_out_norm", il);
-           }
+           cur = build_norm(cur,
+                   model.layers[il].attn_out_norm, NULL,
+                   LLM_NORM_RMS, il);
+           cb(cur, "attn_out_norm", il);

            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
-           // MoE branch
            cur = build_norm(ffn_inp,
                    model.layers[il].ffn_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "ffn_norm", il);

-           cur = build_moe_ffn(cur,
+           // MoE branch
+           ggml_tensor * moe_out = build_moe_ffn(cur,
                    model.layers[il].ffn_gate_inp,
                    model.layers[il].ffn_up_exps,
                    model.layers[il].ffn_gate_exps,

@@ -7132,18 +7156,28 @@ struct llm_build_grok : public llm_graph_context {
                    false, 0.0,
                    LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                    il);
-           cb(cur, "ffn_moe_out", il);
+           cb(moe_out, "ffn_moe_out", il);

-           // Grok
-           // if layer_out_norm is present then apply it before adding the input
-           // Idea: maybe ffn_out_norm is a better name
-           if (model.layers[il].layer_out_norm) {
-               cur = build_norm(cur,
-                       model.layers[il].layer_out_norm, NULL,
-                       LLM_NORM_RMS, il);
-               cb(cur, "layer_out_norm", il);
+           if (model.layers[il].ffn_up) {
+               ggml_tensor * ffn_out = build_ffn(cur,
+                       model.layers[il].ffn_up,   NULL, NULL,
+                       model.layers[il].ffn_gate, NULL, NULL,
+                       model.layers[il].ffn_down, NULL, NULL,
+                       NULL,
+                       LLM_FFN_GELU, LLM_FFN_PAR, il);
+               cb(ffn_out, "ffn_out", il);
+
+               cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
+               cb(cur, "ffn_out", il);
+           } else {
+               cur = moe_out;
            }

+           cur = build_norm(cur,
+                   model.layers[il].ffn_post_norm, NULL,
+                   LLM_NORM_RMS, il);
+           cb(cur, "ffn_post_norm", il);
+
            cur = ggml_add(ctx0, cur, ffn_inp);
            cb(cur, "ffn_out", il);

@@ -7166,10 +7200,14 @@ struct llm_build_grok : public llm_graph_context {
        // lm_head
        cur = build_lora_mm(model.output, cur);

-       // Grok
-       // multiply logits by output_multiplier_scale of 0.5773502691896257
-       cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+       cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
+
+       // final logit soft-capping
+       if (hparams.f_final_logit_softcapping) {
+           cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+           cur = ggml_tanh(ctx0, cur);
+           cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+       }

        cb(cur, "result_output", -1);
        res->t_logits = cur;
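Restating the new graph math above: when the grok-2 dense FFN path is present it is combined with the MoE output as

x_{\text{ffn}} = \tfrac{\sqrt{2}}{2}\,\bigl(x_{\text{dense}} + x_{\text{moe}}\bigr)

and the final logits, with z the raw lm_head outputs, s_l = f_logit_scale and cap c = f_final_logit_softcapping, are

z' = \begin{cases} c \cdot \tanh\!\left(\dfrac{s_l \, z}{c}\right) & c \neq 0 \\ s_l \, z & c = 0 \ \text{(grok-1: no final cap)} \end{cases}

The sqrt(2)/2 factor keeps the summed branch on roughly the same scale as a single branch; that rationale is an inference from the formula, not stated in the commit.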
@@ -1799,7 +1799,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
            new_type = tensor->type;
            new_data = tensor->data;
            new_size = ggml_nbytes(tensor);
-           LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0);
+           LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0);
        } else {
            const int64_t nelements = ggml_nelements(tensor);

@@ -1916,8 +1916,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:
    }
    close_ofstream();

-   LLAMA_LOG_INFO("%s: model size  = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0);
-   LLAMA_LOG_INFO("%s: quant size  = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0);
+   LLAMA_LOG_INFO("%s: model size  = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0);
+   LLAMA_LOG_INFO("%s: quant size  = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0);

    if (qs.n_fallback > 0) {
        LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n",
@@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                };
                break;
+           case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+               regex_exprs = {
+                   // original regex from tokenizer.json
+                   // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                   "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+               };
+               break;
            default:
                // default regex for BPE tokenization pre-processing
                regex_exprs = {

@@ -1974,6 +1981,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "kimi-k2") {
                pre_type     = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
                clean_spaces = false;
+           } else if (
+               tokenizer_pre == "grok-2") {
+               pre_type     = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+               clean_spaces = false;
            } else {
                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
            }

@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
    LLAMA_VOCAB_PRE_TYPE_HUNYUAN       = 36,
    LLAMA_VOCAB_PRE_TYPE_KIMI_K2       = 37,
    LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
+   LLAMA_VOCAB_PRE_TYPE_GROK_2        = 39,
 };

 struct LLM_KV;
@@ -59,6 +59,7 @@ bool llama_supports_mlock(void) {

 bool llama_supports_gpu_offload(void) {
     return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
+           ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr ||
            llama_supports_rpc();
 }

@@ -184,8 +185,13 @@ static struct llama_model * llama_model_load_from_file_impl(
             model->devices.push_back(*dev);
         }
     } else {
+        // default device selection
+
+        // build list of available devices
+        std::vector<ggml_backend_dev_t> gpus;
+        std::vector<ggml_backend_dev_t> igpus;
         std::vector<ggml_backend_dev_t> rpc_servers;
-        // use all available devices
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
             ggml_backend_dev_t dev = ggml_backend_dev_get(i);
             switch (ggml_backend_dev_type(dev)) {

@@ -194,19 +200,51 @@ static struct llama_model * llama_model_load_from_file_impl(
                     // skip CPU backends since they are handled separately
                     break;

-                case GGML_BACKEND_DEVICE_TYPE_GPU:
+                case GGML_BACKEND_DEVICE_TYPE_GPU: {
                     ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
                     if (ggml_backend_reg_name(reg) == std::string("RPC")) {
                         rpc_servers.push_back(dev);
                     } else {
-                        model->devices.push_back(dev);
+                        // check if there is already a GPU with the same device id
+                        ggml_backend_dev_props props;
+                        ggml_backend_dev_get_props(dev, &props);
+                        auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) {
+                            ggml_backend_dev_props d_props;
+                            ggml_backend_dev_get_props(d, &d_props);
+                            if (props.device_id && d_props.device_id) {
+                                return strcmp(props.device_id, d_props.device_id) == 0;
+                            }
+                            return false;
+                        });
+
+                        if (it != gpus.end()) {
+                            LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n",
+                                    __func__,
+                                    ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+                                    props.device_id ? props.device_id : "unknown id",
+                                    ggml_backend_dev_name(*it), ggml_backend_dev_description(*it));
+                        } else {
+                            gpus.push_back(dev);
+                        }
                     }
                     break;
+                }
+
+                case GGML_BACKEND_DEVICE_TYPE_IGPU:
+                    igpus.push_back(dev);
+                    break;
             }
         }
-        // add RPC servers at the front of the list
-        if (!rpc_servers.empty()) {
-            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
-        }
+
+        // add RPC servers at the front of the list to minimize network transfers
+        model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+
+        // add GPUs
+        model->devices.insert(model->devices.end(), gpus.begin(), gpus.end());
+
+        // add integrated GPUs only if no other devices were found
+        if (model->devices.empty()) {
+            model->devices.insert(model->devices.end(), igpus.begin(), igpus.end());
+        }
     }

@@ -227,9 +265,12 @@ static struct llama_model * llama_model_load_from_file_impl(
     }

     for (auto * dev : model->devices) {
-        size_t free, total; // NOLINT
-        ggml_backend_dev_memory(dev, &free, &total);
-        LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
+        ggml_backend_dev_props props;
+        ggml_backend_dev_get_props(dev, &props);
+        LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__,
+                ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
+                props.device_id ? props.device_id : "unknown id",
+                props.memory_free/1024/1024);
     }

     const int status = llama_model_load(path_model, splits, *model, params);
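The resulting default order is: RPC servers first (to minimize network transfers), then discrete GPUs deduplicated by device_id, then integrated GPUs only as a fallback when nothing else was found. A standalone sketch that prints what the selection would consider, using only APIs that appear in the diff above:

#include "ggml-backend.h"
#include <cstdio>

// Sketch: enumerate devices the way the default selection walks them.
int main() {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        const char * kind = "other";
        switch (ggml_backend_dev_type(dev)) {
            case GGML_BACKEND_DEVICE_TYPE_CPU:  kind = "cpu (handled separately)"; break;
            case GGML_BACKEND_DEVICE_TYPE_GPU:  kind = "gpu (RPC devices sort first)"; break;
            case GGML_BACKEND_DEVICE_TYPE_IGPU: kind = "igpu (fallback only)"; break;
            default: break;
        }
        printf("%s (%s): %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), kind);
    }
    return 0;
}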
@@ -128,7 +128,7 @@ static std::string get_gpu_info() {
     for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
         auto * dev      = ggml_backend_dev_get(i);
         auto   dev_type = ggml_backend_dev_type(dev);
-        if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU) {
+        if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) {
             gpu_list.push_back(ggml_backend_dev_description(dev));
         }
     }

@@ -945,6 +945,7 @@ struct cmd_params_instance {
                 exit(1);
             }
         }
+        // FIXME: use llama.cpp device selection logic
         // add local GPU devices if any
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
             ggml_backend_dev_t dev = ggml_backend_dev_get(i);

@@ -957,6 +958,10 @@ struct cmd_params_instance {
                 case GGML_BACKEND_DEVICE_TYPE_GPU:
                     devices.push_back(dev);
                     break;
+
+                case GGML_BACKEND_DEVICE_TYPE_IGPU:
+                    // iGPUs are not used when there are RPC servers
+                    break;
             }
         }
         devices.push_back(nullptr);
@@ -384,5 +384,5 @@ These options provide extra functionality and customization when running the LL
 - `--verbose-prompt`: Print the prompt before generating text.
 - `--no-display-prompt`: Don't print prompt at generation.
 - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used.
-- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance.
+- `-ts SPLIT, --tensor-split SPLIT`: When using multiple devices this option controls how tensors should be split across devices. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each device should get in order. For example, "3,2" will assign 60% of the data to device 0 and 40% to device 1. By default, the data is split in proportion to VRAM, but this may not be optimal for performance. The list of the devices which are being used is printed on startup and can be different from the device list given by `--list-devices` or e.g. `nvidia-smi`.
 - `-hfr URL --hf-repo URL`: The url to the Hugging Face model repository. Used in conjunction with `--hf-file` or `-hff`. The model is downloaded and stored in the file provided by `-m` or `--model`. If `-m` is not provided, the model is auto-stored in the path specified by the `LLAMA_CACHE` environment variable or in an OS-specific local cache.
@@ -406,6 +406,7 @@ struct clip_ctx {
         }
         if (!backend) {
             backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr);
+            backend = backend ? backend : ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU, nullptr);
         }
     }
@@ -227,14 +227,8 @@ static ggml_backend_t create_backend(const rpc_server_params & params) {
         }
     }

-    // try to initialize a GPU backend first
     if (!backend) {
-        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr);
-    }
-
-    // if there aren't GPU backends fallback to CPU backend
-    if (!backend) {
-        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        backend = ggml_backend_init_best();
     }

     if (backend) {
@@ -2313,7 +2313,7 @@ struct server_context {
         // thinking is enabled if:
         // 1. It's not explicitly disabled (reasoning_budget == 0)
         // 2. The chat template supports it
-        const bool enable_thinking = params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
+        const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get());
         SRV_INF("Enable thinking? %d\n", enable_thinking);

         oai_parser_opt = {

@@ -2372,7 +2372,7 @@ struct server_context {
         }

         if (ret != nullptr) {
-            SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
+            SLT_INF(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %.3f (> %.3f thold)\n", lcs_len, similarity, slot_prompt_similarity);
         }
     }

@@ -2394,7 +2394,7 @@ struct server_context {
         }

         if (ret != nullptr) {
-            SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
+            SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last);
         }
     }