Merge remote-tracking branch 'upstream/master'
commit 2aa05c6cf0
@ -387,6 +387,39 @@ jobs:
|
|||
cd build
|
||||
ctest -L main --verbose
|
||||
|
||||
ubuntu-24-cmake-vulkan-deb:
|
||||
runs-on: ubuntu-24.04
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ubuntu-24-cmake-vulkan-deb
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get install -y glslc libvulkan-dev libcurl4-openssl-dev
|
||||
|
||||
- name: Configure
|
||||
id: cmake_configure
|
||||
run: |
|
||||
cmake -B build \
|
||||
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
|
||||
-DGGML_BACKEND_DL=ON \
|
||||
-DGGML_CPU_ALL_VARIANTS=ON \
|
||||
-DGGML_VULKAN=ON
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
cmake --build build -j $(nproc)
|
||||
|
||||
ubuntu-24-cmake-vulkan:
|
||||
runs-on: ubuntu-24.04
@ -1272,6 +1305,81 @@ jobs:
|
|||
cd examples/llama.android
|
||||
./gradlew build --no-daemon
|
||||
|
||||
android-ndk-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
OPENCL_VERSION: 2025.07.22
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'arm64-cpu'
|
||||
defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_CURL=OFF -D GGML_OPENMP=OFF'
|
||||
- build: 'arm64-snapdragon'
|
||||
defines: '--preset arm64-android-snapdragon-release'
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install OpenCL Headers and Libs
|
||||
id: install_opencl
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
mkdir opencl
|
||||
curl -L -o opencl/clhpp.tar.gz https://github.com/KhronosGroup/OpenCL-CLHPP/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
curl -L -o opencl/headers.tar.gz https://github.com/KhronosGroup/OpenCL-Headers/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
curl -L -o opencl/icd-loader.tar.gz https://github.com/KhronosGroup/OpenCL-ICD-Loader/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
tar -xaf opencl/headers.tar.gz -C opencl
|
||||
tar -xaf opencl/clhpp.tar.gz -C opencl
|
||||
tar -xaf opencl/icd-loader.tar.gz -C opencl
|
||||
sudo cp -r opencl/OpenCL-Headers-${OPENCL_VERSION}/CL ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
sudo cp -r opencl/OpenCL-CLHPP-${OPENCL_VERSION}/include/CL/* ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include/CL
|
||||
cd opencl/OpenCL-ICD-Loader-${OPENCL_VERSION}
|
||||
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -DOPENCL_ICD_LOADER_HEADERS_DIR=${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=31 -DANDROID_STL=c++_shared
|
||||
cmake --build build
|
||||
sudo cp build/libOpenCL.so ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
|
||||
rm -rf opencl
|
||||
|
||||
- name: Install Hexagon SDK
|
||||
id: install_hexsdk
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
env:
|
||||
HEXSDK_VER: 6.4.0.2
|
||||
HEXTLS_VER: 19.0.04
|
||||
run: |
|
||||
curl -L -o hex-sdk.tar.gz https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v$HEXSDK_VER/hexagon-sdk-v$HEXSDK_VER-amd64-lnx.tar.xz
|
||||
mkdir hex-sdk
|
||||
tar -xaf hex-sdk.tar.gz -C hex-sdk
|
||||
ls -l hex-sdk
|
||||
sudo mv hex-sdk /opt/hexagon
|
||||
echo "HEXAGON_SDK_ROOT=/opt/hexagon/$HEXSDK_VER" >> "$GITHUB_ENV"
|
||||
echo "HEXAGON_TOOLS_ROOT=/opt/hexagon/$HEXSDK_VER/tools/HEXAGON_Tools/$HEXTLS_VER" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_HLOS_ARCH=64" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_TOOLS_VARIANT=toolv19" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_NO_QURT_INC=0" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_DSP_ARCH=v73" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Update CMake presets
|
||||
id: update_presets
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
|
||||
- name: Build
|
||||
id: ndk_build
|
||||
run: |
|
||||
cmake ${{ matrix.defines }} -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix pkg-adb/llama.cpp
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
echo "FIXME: test on devices"
|
||||
|
||||
openEuler-latest-cmake-cann:
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||
defaults:
|
||||
@ -1515,3 +1623,29 @@ jobs:
|
|||
run: |
|
||||
vulkaninfo --summary
|
||||
GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
|
||||
|
||||
ggml-ci-arm64-cpu-kleidiai:
|
||||
runs-on: ubuntu-22.04-arm
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: ccache
|
||||
uses: ggml-org/ccache-action@v1.2.16
|
||||
with:
|
||||
key: ggml-ci-arm64-cpu-kleidiai
|
||||
evict-old-files: 1d
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential libcurl4-openssl-dev
|
||||
|
||||
- name: Test
|
||||
id: ggml-ci
|
||||
run: |
|
||||
GG_BUILD_KLEIDIAI=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
@ -134,6 +134,8 @@ jobs:
|
|||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 's390x-z15' # z15 because our CI runners are on z15
|
||||
os: ubuntu-22.04-s390x
|
||||
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
|
||||
# - build: 'arm64'
|
||||
# os: ubuntu-22.04-arm
@ -3,10 +3,12 @@ name: Update Operations Documentation
|
|||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'docs/ops.md'
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/ops.md'
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
@ -55,7 +55,7 @@
|
|||
/ggml/src/ggml-cuda/common.cuh @slaren
|
||||
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
|
||||
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
|
||||
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
@ -65,6 +65,7 @@
|
|||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
|
||||
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
|
||||
- [x] [Jamba](https://huggingface.co/ai21labs)
|
||||
- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
|
||||
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
|
||||
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
@ -138,6 +139,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
|
||||
#### Multimodal
@ -187,6 +189,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
|
||||
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
|
||||
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
|
||||
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
|
||||
|
||||
</details>
@ -278,6 +281,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
|
||||
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
|
||||
|
||||
## Obtaining and quantizing models
ci/run.sh (33 changed lines)
@ -22,6 +22,9 @@
|
|||
# # with MUSA support
|
||||
# GG_BUILD_MUSA=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
# # with KLEIDIAI support
|
||||
# GG_BUILD_KLEIDIAI=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
|
||||
#
|
||||
|
||||
if [ -z "$2" ]; then
|
||||
echo "usage: $0 <output-dir> <mnt-dir>"
|
||||
|
|
@ -72,7 +75,7 @@ if [ ! -z ${GG_BUILD_ROCM} ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_SYCL} ]; then
|
||||
|
|
@ -115,6 +118,34 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then
|
|||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm"
|
||||
fi
|
||||
|
||||
if [ -n "${GG_BUILD_KLEIDIAI}" ]; then
|
||||
echo ">>===== Enabling KleidiAI support"
|
||||
|
||||
CANDIDATES=("armv9-a+dotprod+i8mm" "armv8.6-a+dotprod+i8mm" "armv8.2-a+dotprod")
|
||||
CPU=""
|
||||
|
||||
for cpu in "${CANDIDATES[@]}"; do
|
||||
if echo 'int main(){}' | ${CXX:-c++} -march="$cpu" -x c++ - -c -o /dev/null >/dev/null 2>&1; then
|
||||
CPU="$cpu"
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$CPU" ]; then
|
||||
echo "ERROR: None of the required ARM baselines (armv9/armv8.6/armv8.2 + dotprod) are supported by this compiler."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ">>===== Using ARM baseline: ${CPU}"
|
||||
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA:+$CMAKE_EXTRA } \
|
||||
-DGGML_NATIVE=OFF \
|
||||
-DGGML_CPU_KLEIDIAI=ON \
|
||||
-DGGML_CPU_AARCH64=ON \
|
||||
-DGGML_CPU_ARM_ARCH=${CPU} \
|
||||
-DBUILD_SHARED_LIBS=OFF"
|
||||
fi
|
||||
|
||||
## helpers
|
||||
|
||||
# download a file if it does not exist or if it is outdated
common/arg.cpp (309 changed lines)
@ -1760,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
|
||||
add_opt(common_arg(
|
||||
{"-t", "--threads"}, "N",
|
||||
string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
|
||||
string_format("number of CPU threads to use during generation (default: %d)", params.cpuparams.n_threads),
|
||||
[](common_params & params, int value) {
|
||||
params.cpuparams.n_threads = value;
|
||||
if (params.cpuparams.n_threads <= 0) {
|
||||
|
|
@ -1935,6 +1935,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.n_ctx_checkpoints = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--cache-ram", "-cram"}, "N",
|
||||
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)", params.cache_ram_mib),
|
||||
[](common_params & params, int value) {
|
||||
params.cache_ram_mib = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_CACHE_RAM").set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
add_opt(common_arg(
|
||||
{"--kv-unified", "-kvu"},
|
||||
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
|
||||
|
|
@ -3350,7 +3358,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
add_opt(common_arg(
|
||||
{"--chat-template-kwargs"}, "STRING",
|
||||
string_format("sets additional params for the json template parser"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
[](common_params & params, const std::string & value) {
|
||||
auto parsed = json::parse(value);
|
||||
for (const auto & item : parsed.items()) {
|
||||
params.default_template_kwargs[item.key()] = item.value().dump();
|
||||
|
|
@ -3427,12 +3435,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
"- none: leaves thoughts unparsed in `message.content`\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
|
||||
"- deepseek: puts thoughts in `message.reasoning_content`\n"
|
||||
"- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`\n"
|
||||
"(default: auto)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_format = common_reasoning_format_from_name(value);
|
||||
|
|
@ -3561,21 +3570,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
common_log_set_file(common_log_main(), value.c_str());
|
||||
}
|
||||
));
|
||||
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"--log-colors"}, "[on|off|auto]",
|
||||
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
|
||||
"'auto' enables colors when output is to a terminal",
|
||||
[](common_params &, const std::string & value) {
|
||||
if (is_truthy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
|
||||
} else if (is_falsey(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
|
||||
} else if (is_autoy(value)) {
|
||||
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
|
||||
}
|
||||
}
|
||||
).set_env("LLAMA_LOG_COLORS"));
|
||||
add_opt(common_arg(
|
||||
{"-v", "--verbose", "--log-verbose"},
|
||||
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
|
||||
|
|
@ -3841,7 +3852,87 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
}
|
||||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
// model-specific
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-steps"}, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-visual"},
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)", params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-eps"}, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-algorithm"}, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)", params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-alg-temp"}, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-block-length"}, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-cfg-scale"}, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{"--diffusion-add-gumbel-noise"}, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format("adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)", (double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format("(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-decay-epochs", "--learning-rate-decay-epochs"}, "ALPHA",
|
||||
string_format("(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)", (double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-wd", "--weight-decay"}, "WD",
|
||||
string_format("adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).", (double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-val-split", "--val-split"}, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).", (double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-epochs", "--epochs"}, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; }
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{"-opt", "--optimizer"}, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
}
|
||||
).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
|
||||
// presets
|
||||
add_opt(common_arg(
|
||||
{"--tts-oute-default"},
|
||||
string_format("use default OuteTTS models (note: can download weights from the internet)"),
|
||||
|
|
@ -3854,39 +3945,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-bge-small-en-default"},
|
||||
string_format("use default bge-small-en-v1.5 model (note: can download weights from the internet)"),
|
||||
{"--embd-gemma-default"},
|
||||
string_format("use default EmbeddingGemma model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/bge-small-en-v1.5-Q8_0-GGUF";
|
||||
params.model.hf_file = "bge-small-en-v1.5-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-e5-small-en-default"},
|
||||
string_format("use default e5-small-v2 model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/e5-small-v2-Q8_0-GGUF";
|
||||
params.model.hf_file = "e5-small-v2-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{"--embd-gte-small-default"},
|
||||
string_format("use default gte-small model (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gte-small-Q8_0-GGUF";
|
||||
params.model.hf_file = "gte-small-q8_0.gguf";
|
||||
params.embd_normalize = 2;
|
||||
params.n_ctx = 512;
|
||||
params.model.hf_repo = "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF";
|
||||
params.model.hf_file = "embeddinggemma-300M-qat-Q4_0.gguf";
|
||||
params.port = 8011;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 2048;
|
||||
params.n_parallel = 32;
|
||||
params.n_ctx = 2048*params.n_parallel;
|
||||
params.verbose_prompt = true;
|
||||
params.embedding = true;
|
||||
}
|
||||
|
|
@ -3981,96 +4049,65 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-steps" }, "N",
|
||||
string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
|
||||
[](common_params & params, int value) { params.diffusion.steps = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-visual" },
|
||||
string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
|
||||
params.diffusion.visual_mode ? "true" : "false"),
|
||||
[](common_params & params) { params.diffusion.visual_mode = true; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-20b-default"},
|
||||
string_format("use gpt-oss-20b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-20b-GGUF";
|
||||
params.model.hf_file = "gpt-oss-20b-mxfp4.gguf";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-eps" }, "F",
|
||||
string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-algorithm" }, "N",
|
||||
string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
|
||||
params.diffusion.algorithm),
|
||||
[](common_params & params, int value) { params.diffusion.algorithm = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-alg-temp" }, "F",
|
||||
string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--gpt-oss-120b-default"},
|
||||
string_format("use gpt-oss-120b (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gpt-oss-120b-GGUF";
|
||||
params.port = 8013;
|
||||
params.n_ubatch = 2048;
|
||||
params.n_batch = 32768;
|
||||
params.n_parallel = 2;
|
||||
params.n_ctx = 131072*params.n_parallel;
|
||||
params.sampling.temp = 1.0f;
|
||||
params.sampling.top_p = 1.0f;
|
||||
params.sampling.top_k = 0;
|
||||
params.sampling.min_p = 0.01f;
|
||||
params.use_jinja = true;
|
||||
//params.default_template_kwargs["reasoning_effort"] = "\"high\"";
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-block-length" }, "N",
|
||||
string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
|
||||
[](common_params & params, int value) { params.diffusion.block_length = value; }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-cfg-scale" }, "F",
|
||||
string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
add_opt(common_arg(
|
||||
{ "--diffusion-add-gumbel-noise" }, "F",
|
||||
string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
|
||||
[](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
|
||||
).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
|
||||
{"--vision-gemma-4b-default"},
|
||||
string_format("use Gemma 3 4B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-4b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
|
||||
add_opt(
|
||||
common_arg({ "-lr", "--learning-rate" }, "ALPHA",
|
||||
string_format(
|
||||
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
|
||||
(double) params.lr.lr0),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) final learning rate after decay (if -decay-epochs is set, default=%.2g)",
|
||||
(double) params.lr.lr_min),
|
||||
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(
|
||||
common_arg({ "-decay-epochs", "--learning-rate-decay-epochs" }, "ALPHA",
|
||||
string_format(
|
||||
"(if >0) decay learning rate to -lr-min after this many epochs (exponential decay, default=%.2g)",
|
||||
(double) params.lr.decay_epochs),
|
||||
[](common_params & params, const std::string & value) { params.lr.decay_epochs = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg(
|
||||
{ "-wd", "--weight-decay" }, "WD",
|
||||
string_format(
|
||||
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
|
||||
(double) params.lr.wd),
|
||||
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-val-split", "--val-split" }, "FRACTION",
|
||||
string_format("fraction of data to use as validation set for training (default: %.2g).",
|
||||
(double) params.val_split),
|
||||
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
|
||||
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
|
||||
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
|
||||
[](common_params & params, const std::string & name) {
|
||||
params.optimizer = common_opt_get_optimizer(name.c_str());
|
||||
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
|
||||
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
|
||||
}
|
||||
})
|
||||
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
|
||||
{"--vision-gemma-12b-default"},
|
||||
string_format("use Gemma 3 12B QAT (note: can download weights from the internet)"),
|
||||
[](common_params & params) {
|
||||
params.model.hf_repo = "ggml-org/gemma-3-12b-it-qat-GGUF";
|
||||
params.port = 8014;
|
||||
params.n_ctx = 0;
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER}));
|
||||
|
||||
return ctx_arg;
|
||||
}
@ -3,9 +3,12 @@
|
|||
#include "log.h"
|
||||
#include "regex-partial.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <vector>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
|
@ -166,6 +169,27 @@ void common_chat_msg_parser::consume_literal(const std::string & literal) {
|
|||
}
|
||||
|
||||
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
|
||||
std::string pending_reasoning_prefix;
|
||||
|
||||
if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto set_reasoning_prefix = [&](size_t prefix_pos) {
|
||||
if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) {
|
||||
return;
|
||||
}
|
||||
if (prefix_pos + start_think.size() > input_.size()) {
|
||||
pending_reasoning_prefix.clear();
|
||||
return;
|
||||
}
|
||||
// Capture the exact literal that opened the reasoning section so we can
|
||||
// surface it back to callers. This ensures formats that force the
|
||||
// reasoning tag open (e.g. DeepSeek R1) retain their original prefix
|
||||
// instead of dropping it during parsing.
|
||||
pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size());
|
||||
};
|
||||
|
||||
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
|
||||
auto stripped_reasoning = string_strip(reasoning);
|
||||
if (stripped_reasoning.empty()) {
|
||||
|
|
@ -178,28 +202,116 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
|
|||
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
|
||||
}
|
||||
} else {
|
||||
if (!pending_reasoning_prefix.empty()) {
|
||||
add_reasoning_content(pending_reasoning_prefix);
|
||||
pending_reasoning_prefix.clear();
|
||||
}
|
||||
add_reasoning_content(stripped_reasoning);
|
||||
}
|
||||
};
|
||||
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
|
||||
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
|
||||
if (auto res = try_find_literal(end_think)) {
|
||||
handle_reasoning(res->prelude, /* closed */ true);
|
||||
consume_spaces();
|
||||
return true;
|
||||
}
|
||||
auto rest = consume_rest();
|
||||
|
||||
const size_t saved_pos = pos_;
|
||||
const size_t saved_content_size = result_.content.size();
|
||||
const size_t saved_reasoning_size = result_.reasoning_content.size();
|
||||
|
||||
auto restore_state = [&]() {
|
||||
move_to(saved_pos);
|
||||
result_.content.resize(saved_content_size);
|
||||
result_.reasoning_content.resize(saved_reasoning_size);
|
||||
};
|
||||
|
||||
// Allow leading whitespace to be preserved as content when reasoning is present at the start
|
||||
size_t cursor = pos_;
|
||||
size_t whitespace_end = cursor;
|
||||
while (whitespace_end < input_.size() && std::isspace(static_cast<unsigned char>(input_[whitespace_end]))) {
|
||||
++whitespace_end;
|
||||
}
|
||||
|
||||
if (whitespace_end >= input_.size()) {
|
||||
restore_state();
|
||||
if (syntax_.thinking_forced_open) {
|
||||
auto rest = input_.substr(saved_pos);
|
||||
if (!rest.empty()) {
|
||||
handle_reasoning(rest, /* closed */ !is_partial());
|
||||
}
|
||||
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
|
||||
// if (!syntax_.thinking_forced_open) {
|
||||
// throw common_chat_msg_partial_exception(end_think);
|
||||
// }
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
cursor = whitespace_end;
|
||||
const size_t remaining = input_.size() - cursor;
|
||||
const size_t start_prefix = std::min(start_think.size(), remaining);
|
||||
const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0;
|
||||
|
||||
if (has_start_tag && start_prefix < start_think.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (has_start_tag) {
|
||||
if (whitespace_end > pos_) {
|
||||
add_content(input_.substr(pos_, whitespace_end - pos_));
|
||||
}
|
||||
set_reasoning_prefix(cursor);
|
||||
cursor += start_think.size();
|
||||
} else if (syntax_.thinking_forced_open) {
|
||||
cursor = whitespace_end;
|
||||
} else {
|
||||
restore_state();
|
||||
return false;
|
||||
}
|
||||
while (true) {
|
||||
if (cursor >= input_.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t end_pos = input_.find(end_think, cursor);
|
||||
if (end_pos == std::string::npos) {
|
||||
std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor);
|
||||
size_t partial_off = string_find_partial_stop(remaining_view, end_think);
|
||||
size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off;
|
||||
if (reasoning_end > cursor) {
|
||||
handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial());
|
||||
}
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (end_pos > cursor) {
|
||||
handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true);
|
||||
} else {
|
||||
handle_reasoning("", /* closed */ true);
|
||||
}
|
||||
|
||||
cursor = end_pos + end_think.size();
|
||||
|
||||
while (cursor < input_.size() && std::isspace(static_cast<unsigned char>(input_[cursor]))) {
|
||||
++cursor;
|
||||
}
|
||||
|
||||
const size_t next_remaining = input_.size() - cursor;
|
||||
if (next_remaining == 0) {
|
||||
move_to(cursor);
|
||||
return true;
|
||||
}
|
||||
|
||||
const size_t next_prefix = std::min(start_think.size(), next_remaining);
|
||||
if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) {
|
||||
if (next_prefix < start_think.size()) {
|
||||
move_to(input_.size());
|
||||
return true;
|
||||
}
|
||||
set_reasoning_prefix(cursor);
|
||||
cursor += start_think.size();
|
||||
continue;
|
||||
}
|
||||
|
||||
move_to(cursor);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
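A side note on the partial-close handling above: the reworked try_parse_reasoning has to cope with streamed output that ends in the middle of the closing tag (e.g. the chunk stops at "</thi"). The core trick is a partial-suffix check; the real code uses the llama.cpp helper string_find_partial_stop. A minimal standalone sketch of the same check (illustrative only; partial_stop_pos is a made-up name, not the actual helper):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <string_view>

    // Position where a partial occurrence of `stop` begins at the end of `text`,
    // or npos if the text does not end with any prefix of `stop`.
    static size_t partial_stop_pos(std::string_view text, std::string_view stop) {
        for (size_t len = std::min(stop.size(), text.size()); len > 0; --len) {
            if (text.compare(text.size() - len, len, stop.substr(0, len)) == 0) {
                return text.size() - len;
            }
        }
        return std::string_view::npos;
    }

    int main() {
        const std::string end_think = "</think>";
        const std::string chunk     = "the model is still reasoning ... </thi"; // truncated close tag

        const size_t pos = partial_stop_pos(chunk, end_think);
        if (pos != std::string::npos) {
            // Everything before `pos` can already be emitted as reasoning content;
            // the tail may be the start of "</think>", so it is held back.
            std::cout << "reasoning so far: " << chunk.substr(0, pos) << "\n";
            std::cout << "held-back tail:   " << chunk.substr(pos)    << "\n";
        }
    }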
|
||||
|
||||
std::string common_chat_msg_parser::consume_rest() {
|
||||
|
|
@ -320,7 +432,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
|||
if (is_arguments_path({})) {
|
||||
// Entire JSON is the arguments and was parsed fully.
|
||||
return consume_json_result {
|
||||
partial->json.dump(),
|
||||
partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true),
|
||||
/* .is_partial = */ false,
|
||||
};
|
||||
}
|
||||
|
|
@ -332,7 +444,7 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
|
|||
std::vector<std::string> path;
|
||||
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
|
||||
if (is_arguments_path(path)) {
|
||||
auto arguments = j.dump();
|
||||
auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true);
|
||||
if (is_partial() && !partial->healing_marker.marker.empty()) {
|
||||
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
|
||||
if (idx != std::string::npos) {
@ -1408,6 +1408,8 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp
|
|||
return data;
|
||||
}
|
||||
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
|
|
@ -2862,6 +2864,7 @@ common_chat_params common_chat_templates_apply(
|
|||
}
|
||||
|
||||
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
@ -33,8 +33,8 @@ struct common_chat_msg_content_part {
|
|||
struct common_chat_msg {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::vector<common_chat_msg_content_part> content_parts = {};
|
||||
std::vector<common_chat_tool_call> tool_calls = {};
|
||||
std::vector<common_chat_msg_content_part> content_parts;
|
||||
std::vector<common_chat_tool_call> tool_calls;
|
||||
std::string reasoning_content;
|
||||
std::string tool_name;
|
||||
std::string tool_call_id;
|
||||
|
|
@ -44,7 +44,7 @@ struct common_chat_msg {
|
|||
bool empty() const {
|
||||
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
||||
}
|
||||
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
||||
for (auto i = 0u; i < tool_calls.size(); i++) {
|
||||
if (ids_cache.size() <= i) {
|
||||
auto id = tool_calls[i].id;
|
|
|||
bool simple_io = false; // improves compatibility with subprocesses and limited consoles
|
||||
bool cont_batching = true; // insert new sequences for decoding on-the-fly
|
||||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool ctx_shift = false; // context shift on infinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
|
|
@ -425,7 +425,8 @@ struct common_params {
|
|||
int32_t timeout_write = timeout_read; // http write timeout in seconds
|
||||
int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
|
||||
int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
|
||||
int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot
|
||||
int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
std::string public_path = ""; // NOLINT
|
||||
|
|
@ -433,7 +434,7 @@ struct common_params {
|
|||
std::string chat_template = ""; // NOLINT
|
||||
bool use_jinja = false; // NOLINT
|
||||
bool enable_chat_template = true;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
|
||||
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
int reasoning_budget = -1;
|
||||
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
@ -5,6 +5,7 @@
|
|||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
|
|
@ -168,6 +169,47 @@ bool common_json_parse(
|
|||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check if a partial of a high surrogate (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+DFFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
|
|
@ -186,6 +228,9 @@ bool common_json_parse(
|
|||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
|
|
@ -205,6 +250,9 @@ bool common_json_parse(
|
|||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
|
|
@ -230,6 +278,9 @@ bool common_json_parse(
|
|||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
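The new healing branches above deal with JSON that was cut off inside a \uXXXX escape: the trailing partial escape is padded with zeros, and if it looks like a high surrogate a low surrogate is appended so the healed string stays valid JSON. A self-contained sketch of the same idea (illustrative only; heal_partial_unicode is a made-up helper, not the llama.cpp function, and it only closes a single string literal):

    #include <iostream>
    #include <regex>
    #include <string>

    static std::string heal_partial_unicode(std::string s) {
        // Trailing, possibly incomplete escape: \u, \uD, \uD8, \uD83, \uD83D
        static const std::regex partial(R"(\\u[0-9a-fA-F]{0,4}$)");
        std::smatch m;
        if (!std::regex_search(s, m, partial)) {
            return s + "\"";                        // nothing to heal, just close the string
        }
        std::string pad(6 - m.length(0), '0');      // complete the escape with zeros
        const std::string esc = m.str(0);
        const bool high_surrogate = esc.size() >= 4 &&
            (esc[2] == 'd' || esc[2] == 'D') &&
            std::string("89abAB").find(esc[3]) != std::string::npos;
        if (high_surrogate) {
            pad += "\\udc00";                       // a lone high surrogate is invalid JSON, pair it up
        }
        return s + pad + "\"";
    }

    int main() {
        // Truncated mid-escape: the fragment below heals to a valid surrogate pair.
        std::cout << heal_partial_unicode(R"("emoji: \uD83)") << "\n";
    }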
@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
|||
return result;
|
||||
}
|
||||
|
||||
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int>::max();
|
||||
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int64_t>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int64_t>::max();
|
||||
|
||||
auto digit_range = [&](char from, char to) {
|
||||
out << "[";
|
||||
|
|
@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
if (has_min) {
|
||||
if (min_value < 0) {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
out << ") | [0] | [1-9] ";
|
||||
more_digits(0, decimals_left - 1);
|
||||
} else if (min_value == 0) {
|
||||
|
|
@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
}
|
||||
digit_range(c, c);
|
||||
out << " (";
|
||||
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
|
||||
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
|
||||
out << ")";
|
||||
if (c < '9') {
|
||||
out << " | ";
|
||||
|
|
@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
|
||||
} else {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
|
||||
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
|
||||
out << ")";
|
||||
}
|
||||
return;
|
||||
|
|
@ -925,17 +925,17 @@ public:
|
|||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int min_value = std::numeric_limits<int>::min();
|
||||
int max_value = std::numeric_limits<int>::max();
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
int64_t max_value = std::numeric_limits<int64_t>::max();
|
||||
if (schema.contains("minimum")) {
|
||||
min_value = schema["minimum"].get<int>();
|
||||
min_value = schema["minimum"].get<int64_t>();
|
||||
} else if (schema.contains("exclusiveMinimum")) {
|
||||
min_value = schema["exclusiveMinimum"].get<int>() + 1;
|
||||
min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
|
||||
}
|
||||
if (schema.contains("maximum")) {
|
||||
max_value = schema["maximum"].get<int>();
|
||||
max_value = schema["maximum"].get<int64_t>();
|
||||
} else if (schema.contains("exclusiveMaximum")) {
|
||||
max_value = schema["exclusiveMaximum"].get<int>() - 1;
|
||||
max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
|
||||
}
|
||||
std::stringstream out;
|
||||
out << "(";
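The int to int64_t widening above matters because JSON Schema bounds are not limited to 32 bits: with plain int, a schema carrying a bound such as 10000000000 would silently wrap before the grammar is built. A quick sketch of the difference, assuming the single-header nlohmann/json used elsewhere in this codebase is available:

    #include <cstdint>
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::ordered_json;

    int main() {
        // A bound that does not fit in a 32-bit int.
        json schema = json::parse(R"({ "type": "integer", "maximum": 10000000000 })");

        int     narrow = schema["maximum"].get<int>();      // silently truncated
        int64_t wide   = schema["maximum"].get<int64_t>();  // value preserved

        std::cout << "get<int>():     " << narrow << "\n";  // some wrapped value, not 10000000000
        std::cout << "get<int64_t>(): " << wide   << "\n";  // 10000000000
    }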
@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
|||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||
import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
|
||||
_mistral_common_installed = True
|
||||
_mistral_import_error_msg = ""
|
||||
except ImportError:
|
||||
_MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("hf-to-gguf")
|
||||
|
|
@ -73,10 +90,8 @@ class ModelBase:
|
|||
use_temp_file: bool
|
||||
lazy: bool
|
||||
dry_run: bool
|
||||
part_names: list[str]
|
||||
is_safetensors: bool
|
||||
hparams: dict[str, Any]
|
||||
tensor_names: set[str] | None
|
||||
model_tensors: dict[str, Callable[[], Tensor]]
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
model_name: str | None
|
||||
metadata_override: Path | None
|
||||
|
|
@ -93,18 +108,23 @@ class ModelBase:
|
|||
# Mistral format specifics
|
||||
is_mistral_format: bool = False
|
||||
disable_mistral_community_chat_template: bool = False
|
||||
sentence_transformers_dense_modules: bool = False
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
|
||||
use_temp_file: bool = False, eager: bool = False,
|
||||
metadata_override: Path | None = None, model_name: str | None = None,
|
||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
|
||||
small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
|
||||
disable_mistral_community_chat_template: bool = False):
|
||||
disable_mistral_community_chat_template: bool = False,
|
||||
sentence_transformers_dense_modules: bool = False):
|
||||
if type(self) is ModelBase or \
|
||||
type(self) is TextModel or \
|
||||
type(self) is MmprojModel:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
if self.is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.fname_out = fname_out
|
||||
|
|
@ -114,25 +134,9 @@ class ModelBase:
|
|||
self.lazy = not eager or (remote_hf_model_id is not None)
|
||||
self.dry_run = dry_run
|
||||
self.remote_hf_model_id = remote_hf_model_id
|
||||
if remote_hf_model_id is not None:
|
||||
self.is_safetensors = True
|
||||
|
||||
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
self.tensor_names = set(name for name in remote_tensors.keys())
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
|
||||
|
||||
self.get_tensors = get_remote_tensors
|
||||
else:
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
self.is_safetensors = len(self.part_names) > 0
|
||||
if not self.is_safetensors:
|
||||
self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
|
||||
self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
|
||||
self.tensor_names = None
|
||||
self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
|
||||
self.metadata_override = metadata_override
|
||||
self.model_name = model_name
|
||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||
|
|
@ -148,6 +152,8 @@ class ModelBase:
|
|||
logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})")
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_BF16
|
||||
|
||||
self.dequant_model()
|
||||
|
||||
# Configure GGUF Writer
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
|
||||
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
|
||||
|
|
@ -169,67 +175,215 @@ class ModelBase:
|
|||
return None
|
||||
raise KeyError(f"could not find any of: {keys}")
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]:
|
||||
tensors: dict[str, Callable[[], Tensor]] = {}
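# NOTE: values are zero-argument closures; binding loop variables via default arguments
# (e.g. `lambda data=data: ...`) captures the current value, so tensor data is only
# materialized when the closure is actually called.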
|
||||
|
||||
if remote_hf_model_id is not None:
|
||||
is_safetensors = True
|
||||
|
||||
logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}")
|
||||
remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
|
||||
for name, remote_tensor in remote_tensors.items():
|
||||
tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r)
|
||||
|
||||
return tensors
|
||||
|
||||
prefix = "model" if not self.is_mistral_format else "consolidated"
|
||||
part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors")
|
||||
is_safetensors: bool = len(part_names) > 0
|
||||
if not is_safetensors:
|
||||
part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
|
||||
|
||||
tensor_names_from_index: set[str] = set()
|
||||
|
||||
if not self.is_mistral_format:
|
||||
index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin"
|
||||
index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin"
|
||||
index_name += ".index.json"
|
||||
index_file = self.dir_model / index_name
|
||||
|
||||
if index_file.is_file():
|
||||
self.tensor_names = set()
|
||||
logger.info(f"gguf: loading model weight map from '{index_name}'")
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index: dict[str, Any] = json.load(f)
|
||||
weight_map = index.get("weight_map")
|
||||
if weight_map is None or not isinstance(weight_map, dict):
|
||||
raise ValueError(f"Can't load 'weight_map' from {index_name!r}")
|
||||
self.tensor_names.update(weight_map.keys())
|
||||
tensor_names_from_index.update(weight_map.keys())
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
else:
|
||||
self.tensor_names = tensor_names_from_parts
|
||||
weight_map = {}
|
||||
|
||||
for part_name in self.part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
for part_name in part_names:
|
||||
logger.info(f"gguf: indexing model part '{part_name}'")
|
||||
ctx: ContextManager[Any]
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
|
||||
else:
|
||||
ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
|
||||
|
||||
with ctx as model_part:
|
||||
tensor_names_from_parts.update(model_part.keys())
|
||||
assert model_part is not None
|
||||
|
||||
for name in model_part.keys():
|
||||
if self.is_safetensors:
|
||||
if is_safetensors:
|
||||
if self.lazy:
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731
|
||||
else:
|
||||
data = model_part.get_tensor(name)
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
else:
|
||||
data = model_part[name]
|
||||
if self.lazy:
|
||||
data = LazyTorchTensor.from_eager(data)
|
||||
yield name, data
|
||||
data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731
|
||||
else:
|
||||
data_gen = lambda data=data: data # noqa: E731
|
||||
tensors[name] = data_gen
|
||||
|
||||
# verify tensor name presence and identify potentially missing files
|
||||
if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0:
|
||||
missing = sorted(self.tensor_names.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
if len(tensor_names_from_index) > 0:
|
||||
tensor_names_from_parts = set(tensors.keys())
|
||||
if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0:
|
||||
missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts))
|
||||
extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index))
|
||||
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
|
||||
if len(extra) == 0 and len(missing_files) > 0:
|
||||
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
|
||||
f"Missing tensors: {missing}")
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
|
||||
return tensors
|
||||
|
||||
def dequant_model(self):
|
||||
tensors_to_remove: list[str] = []
|
||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||
|
||||
if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict):
|
||||
quant_method = quant_config.get("quant_method")
|
||||
|
||||
def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
weight = weight.view(torch.uint8)
|
||||
orig_shape = weight.shape
|
||||
|
||||
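# BitNet packs four 2-bit codes into each uint8; shifting by 0/2/4/6 and masking with 3
# unpacks them, and the `- 1` below maps the stored codes to the ternary values {-1, 0, 1}.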
shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape)))))
|
||||
data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift
|
||||
data = data & 3
|
||||
data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:]))
|
||||
|
||||
# The scale is inverted
|
||||
return data / scale.float()
|
||||
|
||||
def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
scale = scale.float()
|
||||
|
||||
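# Block-quantized checkpoints store one scale per weight block; repeat_interleave expands
# the scale grid to per-element resolution so it can be multiplied with the full weight tensor.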
if (weight_block_size := quant_config.get("weight_block_size")):
|
||||
# TODO: make sure it's a list of integers
|
||||
for i, size in enumerate(weight_block_size):
|
||||
scale = scale.repeat_interleave(size, i)
|
||||
# unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
|
||||
scale = scale[tuple(slice(0, size) for size in weight.shape)]
|
||||
|
||||
return weight.float() * scale
|
||||
|
||||
# ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
|
||||
def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor:
|
||||
bits = quant_config["bits"]
|
||||
assert bits in (2, 3, 4, 8)
|
||||
assert qweight.dtype == qzeros.dtype
|
||||
maxq = (2 ** bits) - 1
|
||||
weight = None
|
||||
zeros = None
|
||||
pack_dtype_bits = qweight.dtype.itemsize * 8
|
||||
|
||||
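# For 2/4/8-bit GPTQ each packed integer holds pack_dtype_bits // bits values; `wf` lists
# the bit offset of every packed value, so a right shift followed by a mask extracts them.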
if bits in [2, 4, 8]:
|
||||
pack_factor = pack_dtype_bits // bits
|
||||
wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0)
|
||||
if self.lazy:
|
||||
wf = LazyTorchTensor.from_eager(wf)
|
||||
|
||||
zeros = torch.bitwise_right_shift(
|
||||
qzeros.unsqueeze(2).expand(-1, -1, pack_factor),
|
||||
wf.unsqueeze(0)
|
||||
).to(torch.int16 if bits == 8 else torch.int8)
|
||||
zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape)
|
||||
|
||||
weight = torch.bitwise_and(
|
||||
torch.bitwise_right_shift(
|
||||
qweight.unsqueeze(1).expand(-1, pack_factor, -1),
|
||||
wf.unsqueeze(-1)
|
||||
).to(torch.int16 if bits == 8 else torch.int8),
|
||||
maxq
|
||||
)
|
||||
elif bits == 3:
|
||||
raise NotImplementedError("3-bit gptq dequantization is not yet implemented")
|
||||
|
||||
assert weight is not None
|
||||
assert zeros is not None
|
||||
|
||||
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
|
||||
|
||||
# gptq_v2 doesn't need to offset zeros
|
||||
if quant_config.get("checkpoint_format", "gptq") == "gptq":
|
||||
zeros += 1
|
||||
|
||||
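# g_idx maps each unpacked weight row to its quantization group, so indexing scales/zeros
# with it broadcasts the per-group parameters across rows before rescaling.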
return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T
|
||||
|
||||
if quant_method == "bitnet":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale"):
|
||||
weight_name = name.removesuffix("_scale")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "fp8":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
weight_name = name.removesuffix("_scale_inv")
|
||||
w = self.model_tensors[weight_name]
|
||||
s = self.model_tensors[name]
|
||||
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s())
|
||||
tensors_to_remove.append(name)
|
||||
elif quant_method == "gptq":
|
||||
for name in self.model_tensors.keys():
|
||||
if name.endswith(".qweight"):
|
||||
base_name = name.removesuffix(".qweight")
|
||||
g_idx = self.model_tensors[base_name + ".g_idx"]
|
||||
qweight = self.model_tensors[base_name + ".qweight"]
|
||||
qzeros = self.model_tensors[base_name + ".qzeros"]
|
||||
scales = self.model_tensors[base_name + ".scales"]
|
||||
new_tensors[base_name + ".weight"] = (
|
||||
lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq(
|
||||
g(), w(), z(), s()
|
||||
)
|
||||
)
|
||||
tensors_to_remove += [
|
||||
base_name + n
|
||||
for n in (
|
||||
".g_idx",
|
||||
".qzeros",
|
||||
".qweight",
|
||||
".scales",
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
|
||||
f"Missing tensors: {missing}\n"
|
||||
f"Extra tensors: {extra}")
|
||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||
|
||||
for name in tensors_to_remove:
|
||||
if name in self.model_tensors:
|
||||
del self.model_tensors[name]
|
||||
|
||||
for name, value in new_tensors.items():
|
||||
self.model_tensors[name] = value
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
for name, gen in self.model_tensors.items():
|
||||
yield name, gen()
|
||||
|
||||
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
|
||||
if key not in gguf.MODEL_TENSORS[self.model_arch]:
|
||||
|
|
@ -588,6 +742,12 @@ class TextModel(ModelBase):
|
|||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
if (n_expert_groups := self.hparams.get("n_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_count(n_expert_groups)
|
||||
logger.info(f"gguf: expert groups count = {n_expert_groups}")
|
||||
if (n_group_used := self.hparams.get("topk_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
|
|
@ -889,8 +1049,8 @@ class TextModel(ModelBase):
|
|||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
|
||||
res = "llada-moe"
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
|
||||
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
|
||||
res = "granite-docling"
|
||||
|
|
@ -1343,6 +1503,17 @@ class MmprojModel(ModelBase):
|
|||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
|
||||
|
||||
def prepare_metadata(self, vocab_only: bool):
|
||||
super().prepare_metadata(vocab_only=vocab_only)
|
||||
|
||||
output_type: str = self.ftype.name.partition("_")[2]
|
||||
|
||||
if self.fname_out.is_dir():
|
||||
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
|
||||
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
|
||||
else:
|
||||
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
|
|
@ -1360,8 +1531,8 @@ class MmprojModel(ModelBase):
|
|||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
|
||||
self.gguf_writer.add_vision_image_mean(image_mean)
|
||||
self.gguf_writer.add_vision_image_std(image_std)
|
||||
|
|
@ -2030,6 +2201,9 @@ class LlamaModel(TextModel):
|
|||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
|
|
@ -2286,18 +2460,21 @@ class ArceeModel(LlamaModel):
|
|||
)
|
||||
class LlavaVisionModel(MmprojModel):
|
||||
img_break_tok_id = -1
|
||||
use_break_tok = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.hparams.get("model_type") == "pixtral":
|
||||
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
|
||||
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
|
||||
elif self.is_mistral_format:
|
||||
# hparams is already vision config here so norm_eps is only defined in global_config.
|
||||
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
|
||||
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
if self.use_break_tok:
|
||||
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
|
||||
logger.info(f"Image break token id: {self.img_break_tok_id}")
|
||||
|
|
@ -3788,6 +3965,10 @@ class Qwen3Model(Qwen2Model):
|
|||
return torch.stack([true_row, false_row], dim=0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "model.vision_" in name:
|
||||
# skip multimodal tensors
|
||||
return []
|
||||
|
||||
if self.is_rerank:
|
||||
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
|
||||
is_real_head = not self.is_tied_embeddings and "lm_head" in name
|
||||
|
|
@ -4355,27 +4536,6 @@ class CodeShellModel(TextModel):
|
|||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
_has_tok_embd = False
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
|
||||
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# assuming token_embd.weight is seen before output.weight
|
||||
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
|
||||
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
|
||||
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
|
||||
self.tensor_names.remove("transformer.wte.weight")
|
||||
elif new_name == tok_embd_name:
|
||||
self._has_tok_embd = True
|
||||
|
||||
return [(new_name, data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register("InternLM2ForCausalLM")
|
||||
class InternLM2Model(TextModel):
|
||||
|
|
@ -5269,6 +5429,53 @@ class Gemma3Model(TextModel):
|
|||
@ModelBase.register("Gemma3TextModel")
|
||||
class EmbeddingGemma(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
|
||||
module_paths = []
|
||||
dense_features_dims = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if self.sentence_transformers_dense_modules:
|
||||
# read modules.json to determine if model has Dense layers
|
||||
modules_file = self.dir_model / "modules.json"
|
||||
if modules_file.is_file():
|
||||
with open(modules_file, encoding="utf-8") as modules_json_file:
|
||||
mods = json.load(modules_json_file)
|
||||
for mod in mods:
|
||||
if mod["type"] == "sentence_transformers.models.Dense":
|
||||
mod_path = mod["path"]
|
||||
# check if model.safetensors file for Dense layer exists
|
||||
model_tensors_file = self.dir_model / mod_path / "model.safetensors"
|
||||
if model_tensors_file.is_file():
|
||||
self.module_paths.append(mod_path)
|
||||
# read config.json of the Dense layer to get in/out features
|
||||
mod_conf_file = self.dir_model / mod_path / "config.json"
|
||||
if mod_conf_file.is_file():
|
||||
with open(mod_conf_file, encoding="utf-8") as mod_conf_json_file:
|
||||
mod_conf = json.load(mod_conf_json_file)
|
||||
# hparams dense_2_feat_out and dense_3_feat_in are required when loading model's dense weights
|
||||
prefix = self._get_dense_prefix(mod_path)
|
||||
if mod_conf["in_features"] is not None and mod_conf["out_features"] is not None:
|
||||
self.dense_features_dims[prefix] = (mod_conf["in_features"], mod_conf["out_features"])
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
from safetensors.torch import load_file
|
||||
module_paths = list(self.module_paths)
|
||||
for i, module_path in enumerate(module_paths):
|
||||
tensors_file = self.dir_model / module_path / "model.safetensors"
|
||||
local_tensors = load_file(tensors_file)
|
||||
tensor_name = self._get_dense_prefix(module_path)
|
||||
for name, local_tensor in local_tensors.items():
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
orig_name = name.replace("linear", tensor_name)
|
||||
name = self.map_tensor_name(orig_name)
|
||||
yield name, local_tensor.clone()
|
||||
|
||||
@staticmethod
|
||||
def _get_dense_prefix(module_path) -> str:
|
||||
"""Get the tensor name prefix for the Dense layer from module path."""
|
||||
tensor_name = "dense_2" if module_path == "2_Dense" else "dense_3"
|
||||
return tensor_name
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
|
@ -5285,6 +5492,10 @@ class EmbeddingGemma(Gemma3Model):
|
|||
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
|
||||
f"instead of {self.hparams['sliding_window']}")
|
||||
self.gguf_writer.add_sliding_window(orig_sliding_window)
|
||||
if self.sentence_transformers_dense_modules:
|
||||
for dense, dims in self.dense_features_dims.items():
|
||||
logger.info(f"Setting dense layer {dense} in/out features to {dims}")
|
||||
self.gguf_writer.add_dense_features_dims(dense, dims[0], dims[1])
|
||||
|
||||
self._try_set_pooling_type()
|
||||
|
||||
|
|
@ -5912,20 +6123,12 @@ class Mamba2Model(TextModel):
|
|||
class JambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||
del tokenizer # unused
|
||||
|
||||
return "gpt-2"
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
# Using Jamba's tokenizer.json causes errors on model load
|
||||
# (something about "byte not found in vocab"),
|
||||
# but there's a working tokenizer.model
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Some Jamba models only have a tokenizer.json, which works.
|
||||
self._set_vocab_gpt2()
|
||||
self._set_vocab_llama_hf()
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
|
||||
|
|
@ -8009,6 +8212,101 @@ class BailingMoeModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeV2ForCausalLM")
|
||||
class BailingMoeV2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
|
||||
self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
else:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["score_function"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
|
||||
|
||||
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "mlp.experts" in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
|
||||
class GroveMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GROVEMOE
|
||||
|
|
@ -8667,6 +8965,13 @@ class SmolLM3Model(LlamaModel):
|
|||
class GptOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT_OSS
|
||||
|
||||
# TODO: remove once MXFP4 is supported more generally
|
||||
def dequant_model(self):
|
||||
quant_config = self.hparams.get("quantization_config")
|
||||
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
|
||||
return
|
||||
return super().dequant_model()
|
||||
|
||||
def transform_nibble_layout(self, tensor):
|
||||
assert tensor.dtype == torch.uint8
|
||||
assert tensor.shape[-1] == 16
|
||||
|
|
@ -9069,7 +9374,7 @@ class MistralModel(LlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None, "mistral_common is not installed"
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
|
||||
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
|
||||
)
|
||||
|
|
@ -9137,6 +9442,21 @@ class PixtralModel(LlavaVisionModel):
|
|||
return super().map_tensor_name(name, try_suffixes)
|
||||
|
||||
|
||||
@ModelBase.register("LightOnOCRForConditionalGeneration")
|
||||
class LightOnOCRVisionModel(LlavaVisionModel):
|
||||
is_mistral_format = False
|
||||
use_break_tok = False
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("model.vision_encoder.", "vision_tower.")
|
||||
name = name.replace("model.vision_projection.", "multi_modal_projector.")
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("KimiVLForConditionalGeneration")
|
||||
class KimiVLModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
@ -9335,6 +9655,13 @@ def parse_args() -> argparse.Namespace:
|
|||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sentence-transformers-dense-modules", action="store_true",
|
||||
help=("Whether to include sentence-transformers dense modules."
|
||||
"It can be used for sentence-transformers models, like google/embeddinggemma-300m"
|
||||
"Default these modules are not included.")
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.print_supported_models and args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
|
|
@ -9397,9 +9724,13 @@ def main() -> None:
|
|||
if args.remote:
|
||||
hf_repo_id = args.model
|
||||
from huggingface_hub import snapshot_download
|
||||
allowed_patterns = ["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]
|
||||
if args.sentence_transformers_dense_modules:
|
||||
# include sentence-transformers dense modules safetensors files
|
||||
allowed_patterns.append("*.safetensors")
|
||||
local_dir = snapshot_download(
|
||||
repo_id=hf_repo_id,
|
||||
allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"])
|
||||
allow_patterns=allowed_patterns)
|
||||
dir_model = Path(local_dir)
|
||||
logger.info(f"Downloaded config and tokenizer to {local_dir}")
|
||||
else:
|
||||
|
|
@ -9435,11 +9766,9 @@ def main() -> None:
|
|||
|
||||
logger.info(f"Loading model: {dir_model.name}")
|
||||
|
||||
if args.mmproj:
|
||||
if "mmproj" not in fname_out.name:
|
||||
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
|
||||
|
||||
is_mistral_format = args.mistral_format
|
||||
if is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
|
||||
|
||||
with torch.inference_mode():
|
||||
|
|
@ -9467,7 +9796,8 @@ def main() -> None:
|
|||
split_max_tensors=args.split_max_tensors,
|
||||
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
|
||||
small_first_shard=args.no_tensor_first_split,
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template
|
||||
remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
|
||||
sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
|
||||
)
|
||||
|
||||
if args.vocab_only:
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ models = [
|
|||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
{
|
||||
"version": 4,
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "arm64-android-snapdragon",
|
||||
"hidden": true,
|
||||
"architecture": { "value": "arm64", "strategy": "external" },
|
||||
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||
"cacheVariables": {
|
||||
"ANDROID_ABI": "arm64-v8a",
|
||||
"ANDROID_PLATFORM": "android-31",
|
||||
"CMAKE_TOOLCHAIN_FILE": "$env{ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"LLAMA_CURL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"name": "arm64-windows-snapdragon",
|
||||
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||
"cacheVariables": {
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"LLAMA_CURL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{ "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] }
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
# Snapdragon-based Android devices
|
||||
|
||||
## How to Build
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Android device is to use the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
|
||||
This method works on Linux, macOS, and Windows. macOS and Windows users should install Docker Desktop.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.3
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
ANDROID_ABI="arm64-v8a"
|
||||
...
|
||||
CMAKE_TOOLCHAIN_FILE="/opt/android-ndk-r28b/build/cmake/android.toolchain.cmake"
|
||||
GGML_HEXAGON="ON"
|
||||
GGML_OPENCL="ON"
|
||||
GGML_OPENMP="OFF"
|
||||
HEXAGON_SDK_ROOT="/opt/hexagon/6.4.0.2"
|
||||
...
|
||||
-- Including OpenCL backend
|
||||
-- Including Hexagon backend
|
||||
...
|
||||
-- Build files have been written to: /workspace/build-snapdragon
|
||||
|
||||
[d]/workspace> cmake --build build-snapdragon
|
||||
...
|
||||
[144/356] Performing build step for 'htp-v73'
|
||||
[1/16] Generating htp_iface_skel.c, htp_iface_stub.c, htp_iface.h
|
||||
[2/16] Building C object CMakeFiles/ggml-htp-v73.dir/hvx-sigmoid.c.obj
|
||||
[3/16] Building C object CMakeFiles/ggml-htp-v73.dir/htp-dma.c.obj
|
||||
[4/16] Building C object CMakeFiles/ggml-htp-v73.dir/worker-pool.c.obj
|
||||
...
|
||||
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v73.so
|
||||
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v75.so
|
||||
...
|
||||
```
|
||||
|
||||
To generate an installable "package", simply use `cmake --install`:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||
-- Install configuration: "Release"
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||
...
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||
...
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
For this step, your device needs to be configured for on-device development.
|
||||
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||
|
||||
Once ADB is enabled, use `adb push` to install `pkg-adb/llama.cpp` on the device.
|
||||
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||
```
|
||||
|
||||
At this point, you should also install some models:
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ wget https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
...
|
||||
2025-10-11 12:04:52 (10.7 MB/s) - ‘Llama-3.2-1B-Instruct-Q4_0.gguf’ saved [773025920/773025920]
|
||||
|
||||
~/src/llama.cpp$ adb push Llama-3.2-1B-Instruct-Q4_0.gguf /data/local/tmp/gguf
|
||||
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||
```
|
||||
|
||||
## How to Run
|
||||
|
||||
The easiest way to run the llama.cpp CLI tools is to use the provided wrapper scripts, which set up all required environment variables.
|
||||
|
||||
llama.cpp supports three backends on Snapdragon-based devices: CPU, Adreno GPU (GPUOpenCL), and Hexagon NPU (HTP0-4).
|
||||
You can select which backend to run the model on using the `D=` variable, which maps to the `--device` option.
|
||||
|
||||
Hexagon NPU behaves as a "GPU" device when it comes to `-ngl` and other offload-related options.
|
||||
|
||||
Here are some examples of running various llama.cpp tools via ADB.
|
||||
|
||||
Simple question for Llama-3.2-1B
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?"
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v79
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb4000072c7955e50
|
||||
...
|
||||
load_tensors: offloading output layer to GPU
|
||||
load_tensors: offloaded 17/17 layers to GPU
|
||||
load_tensors: CPU model buffer size = 225.49 MiB
|
||||
load_tensors: HTP0 model buffer size = 0.26 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 504.00 MiB
|
||||
...
|
||||
I hope this helps you understand the world's most popular cookies! [end of text]
|
||||
...
|
||||
llama_perf_sampler_print: sampling time = 30.08 ms / 487 runs ( 0.06 ms per token, 16191.77 tokens per second)
|
||||
llama_perf_context_print: load time = 617.94 ms
|
||||
llama_perf_context_print: prompt eval time = 80.76 ms / 11 tokens ( 7.34 ms per token, 136.21 tokens per second)
|
||||
llama_perf_context_print: eval time = 9210.59 ms / 475 runs ( 19.39 ms per token, 51.57 tokens per second)
|
||||
llama_perf_context_print: total time = 9454.92 ms / 486 tokens
|
||||
llama_perf_context_print: graphs reused = 473
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 439 = 225 + 136 + 77 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 504 = 504 + 0 + 0 |
|
||||
```
|
||||
|
||||
Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v81
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: allocating new session: HTP1
|
||||
...
|
||||
load_tensors: offloading output layer to GPU
|
||||
load_tensors: offloaded 17/17 layers to GPU
|
||||
load_tensors: CPU model buffer size = 143.86 MiB
|
||||
load_tensors: HTP1 model buffer size = 0.23 MiB
|
||||
load_tensors: HTP1-REPACK model buffer size = 1575.00 MiB
|
||||
load_tensors: HTP0 model buffer size = 0.28 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 2025.00 MiB
|
||||
...
|
||||
llama_context: CPU output buffer size = 0.19 MiB
|
||||
llama_kv_cache: HTP1 KV buffer size = 238.00 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 306.00 MiB
|
||||
llama_kv_cache: size = 544.00 MiB ( 8192 cells, 16 layers, 1/1 seqs), K (q8_0): 272.00 MiB, V (q8_0): 272.00 MiB
|
||||
llama_context: HTP0 compute buffer size = 15.00 MiB
|
||||
llama_context: HTP1 compute buffer size = 15.00 MiB
|
||||
llama_context: CPU compute buffer size = 24.56 MiB
|
||||
...
|
||||
llama_perf_context_print: prompt eval time = 1730.57 ms / 212 tokens ( 8.16 ms per token, 122.50 tokens per second)
|
||||
llama_perf_context_print: eval time = 5624.75 ms / 257 runs ( 21.89 ms per token, 45.69 tokens per second)
|
||||
llama_perf_context_print: total time = 7377.33 ms / 469 tokens
|
||||
llama_perf_context_print: graphs reused = 255
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 742 = 144 + 544 + 54 |
|
||||
llama_memory_breakdown_print: | - HTP1-REPACK | 1575 = 1575 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 2025 = 2025 + 0 + 0 |
|
||||
```
|
||||
|
||||
Op test for MUL_MAT
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ HB=0 ./scripts/snapdragon/adb/run-tool.sh test-backend-ops -b HTP0 -o MUL_MAT
|
||||
...
|
||||
Backend 2/3: HTP0
|
||||
Device description: Hexagon
|
||||
Device memory: 2048 MB (2048 MB free)
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
|
||||
~/src/llama.cpp-hexagon$ M=Llama-3.2-1B-Instruct-Q4_0.gguf ./scripts/snapdragon/adb/run-bench.sh -p 128 -n 64
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v79
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb400007d4b231090
|
||||
| model | size | params | backend | ngl | threads | n_batch | mmap | test | t/s |
|
||||
| ---------------| ---------: | -----: | ---------- | --: | ------: | ------: | ---: | ----: | ------------: |
|
||||
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | pp128 | 169.42 ± 1.75 |
|
||||
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | tg64 | 51.54 ± 1.13 |
|
||||
|
||||
build: 6a8cf8914 (6733)
|
||||
```
|
||||
|
||||
## Environment variables
|
||||
|
||||
- `GGML_HEXAGON_NDEV=1`
|
||||
Controls the number of devices/sessions to allocate. The default is 1.
|
||||
Most quantized models under 4B fit into a single session; an 8B model needs two, and a 20B model needs four.
|
||||
|
||||
- `GGML_HEXAGON_NHVX=0`
|
||||
Controls the number of HVX hardware threads to use. The default is all (actual number varies depending on the hardware version).
|
||||
|
||||
- `GGML_HEXAGON_HOSTBUF=1`
|
||||
Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
|
||||
This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).
|
||||
|
||||
- `GGML_HEXAGON_VERBOSE=1`
|
||||
Enables verbose logging of Ops from the backend. Example output:
|
||||
|
||||
```
|
||||
ggml-hex: HTP0 graph-compute n_nodes 2
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_up.weight x ffn_norm-27 -> ffn_up-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x1
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_gate.weight x ffn_norm-27 -> ffn_gate-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x3
|
||||
ggml-hex: HTP0 graph-compute n_nodes 1
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_down.weight x ffn_gate_par-27 -> ffn_out-27 : 8192:3072 x 8192:1 -> 3072:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x0
|
||||
ggml-hex: HTP0 get-tensor result_output : data 0x7592487000 offset 0 size 513024
|
||||
```
|
||||
|
||||
- `GGML_HEXAGON_PROFILE=1`
|
||||
Generates a host-side profile for the ggml-hexagon Ops.
|
||||
|
||||
- `GGML_HEXAGON_OPMASK=0x0`
|
||||
Allows enabling specific stages of the processing pipeline:
|
||||
|
||||
- `0x1` Enable Op Queue (i.e., queuing Ops into NPU)
|
||||
- `0x2` Enable Dynamic Quantizer (if needed for the Op)
|
||||
- `0x4` Enable Op Compute (MUL_MAT, etc.)
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default)
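
  The mask is just a bitwise OR of the stage flags above. A minimal sketch (a hypothetical helper, not part of llama.cpp) for composing it:

  ```
  OP_QUEUE   = 0x1  # enqueue Ops into the NPU
  DYN_QUANT  = 0x2  # dynamic quantization (when the Op needs it)
  OP_COMPUTE = 0x4  # actual Op compute (MUL_MAT, etc.)

  def opmask(queue=True, quant=True, compute=True) -> str:
      mask = (OP_QUEUE if queue else 0) | (DYN_QUANT if quant else 0) | (OP_COMPUTE if compute else 0)
      return hex(mask)

  print(opmask(compute=False))  # 0x3: enqueue + dynamic quantization, compute stubbed out
  ```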
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
# Hexagon backend developer details
|
||||
|
||||
## Backend libraries
|
||||
|
||||
The Hexagon backend consists of two parts:
|
||||
|
||||
- `libggml-hexagon`
|
||||
This is the regular CPU-side GGML backend library, either shared or statically linked.
|
||||
|
||||
- `libggml-htp-vNN`
|
||||
This is the NPU-side (HTP stands for Hexagon Tensor Processor) shared library that contains the Op dispatcher and kernels.
|
||||
The correct library is selected automatically at runtime based on the HW version.
|
||||
|
||||
Here is an example of the build artifacts:
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ ls -l pkg-adb/llama.cpp/lib/libggml*
|
||||
pkg-adb/llama.cpp/lib/libggml-base.so
|
||||
pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
pkg-adb/llama.cpp/lib/libggml-hexagon.so <<< CPU library
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v73.so <<< HTP op/kernels for Hexagon v73
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
```
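
The runtime selection rule implied by these names is straightforward. A rough sketch (an assumption based on the file naming above and the `Hexagon Arch version vNN` log lines, not the backend's actual loading code):

```
def htp_library_for(hexagon_arch: int) -> str:
    """Pick the NPU-side library for a given Hexagon architecture version."""
    supported = (73, 75, 79, 81)  # versions shipped in the package listed above
    if hexagon_arch not in supported:
        raise ValueError(f"no prebuilt HTP library for v{hexagon_arch}")
    return f"libggml-htp-v{hexagon_arch}.so"

print(htp_library_for(79))  # libggml-htp-v79.so, matching the "Hexagon Arch version v79" log
```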
|
||||
|
||||
## Memory buffers
|
||||
|
||||
The Hexagon NPU backend takes advantage of Snapdragon's unified memory model, where all buffers are fully accessible by the CPU and GPU.
The NPU does have a dedicated tightly-coupled memory called VTCM, but that memory is used only for intermediate data (e.g. dynamically
quantized tensors) or temporary data (chunks of the weight tensors fetched via DMA).
|
||||
|
||||
Please note that currently the Hexagon backend does not implement SET/GET_ROWS Ops because there is no advantage in offloading those
|
||||
to the NPU at this point.
|
||||
|
||||
The backend does allocate non-host buffers for tensors with data types that require repacking: Q4_0, Q8_0, and MXFP4.
From the MMU perspective these buffers are still regular buffers (the CPU can access them normally); they are marked as non-host simply to force
the repacking.
|
||||
|
||||
## Large model handling
|
||||
|
||||
A Hexagon NPU session (aka a Process Domain (PD) in the Hexagon docs) is limited to a memory mapping of around 3.5GB.
In llama.cpp/GGML each Hexagon session is mapped to a single GGML backend device (HTP0, HTP1, etc.).
|
||||
|
||||
In order to map models larger than 3.5GB we need to allocate multiple devices and split the model.
|
||||
For this we're taking advantage of the llama.cpp/GGML multi-GPU layer-splitting support.
|
||||
Each Hexagon device behaves like a GPU from the offload and model splitting perspective.
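
As a rough rule of thumb (an estimate derived from the ~3.5GB mapping limit above, not an exact formula), the number of devices to request via `GGML_HEXAGON_NDEV` can be approximated from the quantized model size:

```
import math

HTP_MAPPING_LIMIT_GB = 3.5  # approximate per-session mapping limit

def estimate_ndev(model_size_gb: float, headroom: float = 0.8) -> int:
    """Rough GGML_HEXAGON_NDEV estimate; headroom leaves room for KV cache and compute buffers."""
    return max(1, math.ceil(model_size_gb / (HTP_MAPPING_LIMIT_GB * headroom)))

print(estimate_ndev(0.7))   # Llama-3.2-1B Q4_0  -> 1
print(estimate_ndev(11.0))  # GPT-OSS-20B Q4_0   -> 4, matching the example below
```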
|
||||
|
||||
Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR.
|
||||
|
||||
```
|
||||
M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32
|
||||
...
|
||||
LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
GGML_HEXAGON_NDEV=4 ./bin/llama-cli --no-mmap -m /data/local/tmp/llama.cpp/../gguf/gpt-oss-20b-Q4_0.gguf
|
||||
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on -ngl 99 --device HTP0,HTP1,HTP2,HTP3 -no-cnv -f surfing.txt
|
||||
...
|
||||
llama_model_loader: - type f32: 289 tensors
|
||||
llama_model_loader: - type q4_0: 96 tensors
|
||||
llama_model_loader: - type q8_0: 2 tensors
|
||||
llama_model_loader: - type mxfp4: 72 tensors
|
||||
...
|
||||
load_tensors: offloaded 25/25 layers to GPU
|
||||
load_tensors: CPU model buffer size = 1182.09 MiB
|
||||
load_tensors: HTP1 model buffer size = 6.64 MiB
|
||||
load_tensors: HTP1-REPACK model buffer size = 2505.94 MiB
|
||||
load_tensors: HTP3 model buffer size = 5.55 MiB
|
||||
load_tensors: HTP3-REPACK model buffer size = 2088.28 MiB
|
||||
load_tensors: HTP0 model buffer size = 7.75 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 2923.59 MiB
|
||||
load_tensors: HTP2 model buffer size = 6.64 MiB
|
||||
load_tensors: HTP2-REPACK model buffer size = 2505.94 MiB
|
||||
...
|
||||
llama_context: n_ctx_per_seq (8192) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
|
||||
llama_context: CPU output buffer size = 0.77 MiB
|
||||
llama_kv_cache_iswa: creating non-SWA KV cache, size = 8192 cells
|
||||
llama_kv_cache: HTP1 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP3 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP2 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: size = 102.00 MiB ( 8192 cells, 12 layers, 1/1 seqs), K (q8_0): 51.00 MiB, V (q8_0): 51.00 MiB
|
||||
llama_kv_cache_iswa: creating SWA KV cache, size = 256 cells
|
||||
llama_kv_cache: HTP1 KV buffer size = 0.80 MiB
|
||||
llama_kv_cache: HTP3 KV buffer size = 0.53 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 1.06 MiB
|
||||
llama_kv_cache: HTP2 KV buffer size = 0.80 MiB
|
||||
llama_kv_cache: size = 3.19 MiB ( 256 cells, 12 layers, 1/1 seqs), K (q8_0): 1.59 MiB, V (q8_0): 1.59 MiB
|
||||
llama_context: HTP0 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP1 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP2 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP3 compute buffer size = 16.06 MiB
|
||||
llama_context: CPU compute buffer size = 98.19 MiB
|
||||
...
|
||||
llama_perf_context_print: prompt eval time = 3843.67 ms / 197 tokens ( 19.51 ms per token, 51.25 tokens per second)
|
||||
llama_perf_context_print: eval time = 1686.13 ms / 31 runs ( 54.39 ms per token, 18.39 tokens per second)
|
||||
llama_perf_context_print: total time = 6266.30 ms / 228 tokens
|
||||
llama_perf_context_print: graphs reused = 30
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP2 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP3 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 1476 = 1208 + 105 + 162 |
|
||||
llama_memory_breakdown_print: | - HTP1-REPACK | 2505 = 2505 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP3-REPACK | 2088 = 2088 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 2923 = 2923 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP2-REPACK | 2505 = 2505 + 0 + 0 |
|
||||
```
|
||||
docs/ops.md
|
|
@ -22,6 +22,7 @@ Legend:
|
|||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
|
|
@ -31,7 +32,7 @@ Legend:
|
|||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
@ -41,6 +42,7 @@ Legend:
|
|||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
|
|
@ -51,7 +53,7 @@ Legend:
|
|||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||
|
|
@ -65,12 +67,12 @@ Legend:
|
|||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
|
|
@ -82,6 +84,7 @@ Legend:
|
|||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
|
@ -92,19 +95,22 @@ Legend:
|
|||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| SWIGLU_OAI | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
|
|
@ -59,6 +59,14 @@
|
|||
"CPU","EXP","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
|
||||
"CPU","GELU_ERF","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
|
||||
"CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
|
||||
"CPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
|
|
@ -119,6 +127,14 @@
|
|||
"CPU","EXP","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
|
||||
"CPU","GELU_ERF","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","CPU"
|
||||
"CPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
|
||||
"CPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
|
||||
"CPU","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
|
||||
"CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=0","support","1","yes","CPU"
|
||||
"CPU","REGLU","type=f16,ne_a=[5,7,11,13],v=0,swapped=0","support","1","yes","CPU"
|
||||
"CPU","REGLU","type=f16,ne_a=[128,2,2,2],v=0,swapped=1","support","1","yes","CPU"
|
||||
|
|
|
|||
|
docs/ops/SYCL.csv: 16362 lines changed (diff suppressed because it is too large)
|
|
@ -3263,27 +3263,27 @@
|
|||
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
|
|
|
|||
|
|
|
|
@ -116,20 +116,39 @@ embedding-convert-model:
|
|||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/embedding/convert-model.sh
|
||||
|
||||
embedding-convert-model-st:
|
||||
$(call validate_embedding_model_path,embedding-convert-model-st)
|
||||
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
|
||||
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
|
||||
./scripts/embedding/convert-model.sh -st
|
||||
|
||||
embedding-run-original-model:
|
||||
$(call validate_embedding_model_path,embedding-run-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \
|
||||
USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \
|
||||
./scripts/embedding/run-original-model.py \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
|
||||
$(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers)
|
||||
|
||||
embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
|
||||
embedding-run-original-model-st: embedding-run-original-model
|
||||
|
||||
embedding-run-converted-model:
|
||||
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
|
||||
$(if $(USE_POOLING),--pooling)
|
||||
|
||||
embedding-run-converted-model-st: USE_POOLING=1
|
||||
embedding-run-converted-model-st: embedding-run-converted-model
|
||||
|
||||
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
|
||||
@./scripts/embedding/compare-embeddings-logits.sh \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
|
||||
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
|
||||
@./scripts/embedding/compare-embeddings-logits.sh \
|
||||
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
|
||||
|
||||
embedding-inspect-original-model:
|
||||
$(call validate_embedding_model_path,embedding-inspect-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
|
||||
|
|
|
|||
|
|
@ -189,6 +189,23 @@ This command will save two files to the `data` directory, one is a binary
|
|||
file containing logits which will be used for comparison with the converted
|
||||
model, and the other is a text file which allows for manual visual inspection.
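For a quick sanity check, the binary file can be read back directly with NumPy, since it is written as a flat array of float32 values. The file name below is only illustrative; the actual name follows the `data/pytorch-<model>-embeddings.bin` pattern used by the scripts:

```python
import numpy as np

# Illustrative file name; adjust to the model that was actually run.
data = np.fromfile("data/pytorch-mymodel-embeddings.bin", dtype=np.float32)
print(f"{data.size} float32 values, first 3: {data[:3]}")
```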
|
||||
|
||||
#### Using SentenceTransformer with numbered layers
|
||||
For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
|
||||
03_Dense, 04_Normalize), use the `-st` targets to apply all these layers:
|
||||
|
||||
```console
|
||||
# Run original model with SentenceTransformer (applies all numbered layers)
|
||||
(venv) $ make embedding-run-original-model-st
|
||||
|
||||
# Run converted model with pooling enabled
|
||||
(venv) $ make embedding-run-converted-model-st
|
||||
```
|
||||
|
||||
This will use the SentenceTransformer library to load and run the model, which
|
||||
automatically applies all the numbered layers in the correct order. This is
|
||||
particularly useful when comparing with models that should include these
|
||||
additional transformation layers beyond just the base model output.
|
||||
|
||||
### Model conversion
|
||||
After updates have been made to [gguf-py](../../gguf-py) to add support for the
|
||||
new model the model can be converted to GGUF format using the following command:
|
||||
|
|
@ -208,6 +225,13 @@ was done manually in the previous steps) and compare the logits:
|
|||
(venv) $ make embedding-verify-logits
|
||||
```
|
||||
|
||||
For models with SentenceTransformer layers, use the `-st` verification target:
|
||||
```console
|
||||
(venv) $ make embedding-verify-logits-st
|
||||
```
|
||||
This convenience target automatically runs both the original model with SentenceTransformer
|
||||
and the converted model with pooling enabled, then compares the results.
|
||||
|
||||
### llama-server verification
|
||||
To verify that the converted model works with llama-server, the following
|
||||
command can be used:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
|
|
@ -8,7 +11,10 @@
|
|||
|
||||
static void print_usage(int, char ** argv) {
|
||||
printf("\nexample usage:\n");
|
||||
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [prompt]\n", argv[0]);
|
||||
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
|
||||
printf("\n");
|
||||
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
|
||||
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
|
@ -17,6 +23,8 @@ int main(int argc, char ** argv) {
|
|||
std::string prompt = "Hello, my name is";
|
||||
int ngl = 0;
|
||||
bool embedding_mode = false;
|
||||
bool pooling_enabled = false;
|
||||
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
|
||||
|
||||
{
|
||||
int i = 1;
|
||||
|
|
@ -41,9 +49,13 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
} else if (strcmp(argv[i], "-embd-mode") == 0) {
|
||||
embedding_mode = true;
|
||||
} else if (strcmp(argv[i], "-pooling") == 0) {
|
||||
pooling_enabled = true;
|
||||
} else if (strcmp(argv[i], "-embd-norm") == 0) {
|
||||
if (i + 1 < argc) {
|
||||
try {
|
||||
embedding_mode = true;
|
||||
embd_norm = std::stoi(argv[++i]);
|
||||
} catch (...) {
|
||||
print_usage(argc, argv);
|
||||
return 1;
|
||||
|
|
@ -112,7 +124,7 @@ int main(int argc, char ** argv) {
|
|||
ctx_params.no_perf = false;
|
||||
if (embedding_mode) {
|
||||
ctx_params.embeddings = true;
|
||||
ctx_params.pooling_type = LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
|
||||
ctx_params.n_ubatch = ctx_params.n_batch;
|
||||
}
|
||||
|
||||
|
|
@ -143,17 +155,27 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
float * logits;
|
||||
int n_logits;
|
||||
float * data_ptr;
|
||||
int data_size;
|
||||
const char * type;
|
||||
std::vector<float> embd_out;
|
||||
|
||||
if (embedding_mode) {
|
||||
logits = llama_get_embeddings(ctx);
|
||||
n_logits = llama_model_n_embd(model) * batch.n_tokens;
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
|
||||
const int n_embeddings = n_embd * n_embd_count;
|
||||
float * embeddings;
|
||||
type = "-embeddings";
|
||||
|
||||
const int n_embd = llama_model_n_embd(model);
|
||||
const int n_embd_count = batch.n_tokens;
|
||||
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
|
||||
embeddings = llama_get_embeddings_seq(ctx, 0);
|
||||
embd_out.resize(n_embeddings);
|
||||
printf("Normalizing embeddings using norm: %d\n", embd_norm);
|
||||
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
|
||||
embeddings = embd_out.data();
|
||||
} else {
|
||||
embeddings = llama_get_embeddings(ctx);
|
||||
}
|
||||
|
||||
printf("Embedding dimension: %d\n", n_embd);
|
||||
printf("\n");
|
||||
|
|
@ -164,7 +186,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// Print first 3 values
|
||||
for (int i = 0; i < 3 && i < n_embd; i++) {
|
||||
printf("%9.6f ", logits[j * n_embd + i]);
|
||||
printf("%9.6f ", embeddings[j * n_embd + i]);
|
||||
}
|
||||
|
||||
printf(" ... ");
|
||||
|
|
@ -172,7 +194,7 @@ int main(int argc, char ** argv) {
|
|||
// Print last 3 values
|
||||
for (int i = n_embd - 3; i < n_embd; i++) {
|
||||
if (i >= 0) {
|
||||
printf("%9.6f ", logits[j * n_embd + i]);
|
||||
printf("%9.6f ", embeddings[j * n_embd + i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -180,27 +202,33 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
printf("\n");
|
||||
|
||||
printf("Embeddings size: %d\n", n_logits);
|
||||
printf("Embeddings size: %d\n", n_embeddings);
|
||||
|
||||
data_ptr = embeddings;
|
||||
data_size = n_embeddings;
|
||||
} else {
|
||||
logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
n_logits = llama_vocab_n_tokens(vocab);
|
||||
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
const int n_logits = llama_vocab_n_tokens(vocab);
|
||||
type = "";
|
||||
printf("Vocab size: %d\n", n_logits);
|
||||
|
||||
data_ptr = logits;
|
||||
data_size = n_logits;
|
||||
}
|
||||
|
||||
std::filesystem::create_directory("data");
|
||||
|
||||
// Save logits to binary file
|
||||
// Save data to binary file
|
||||
char bin_filename[512];
|
||||
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
|
||||
printf("Saving logits to %s\n", bin_filename);
|
||||
printf("Saving data to %s\n", bin_filename);
|
||||
|
||||
FILE * f = fopen(bin_filename, "wb");
|
||||
if (f == NULL) {
|
||||
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
fwrite(logits, sizeof(float), n_logits, f);
|
||||
fwrite(data_ptr, sizeof(float), data_size, f);
|
||||
fclose(f);
|
||||
|
||||
// Also save as text for debugging
|
||||
|
|
@ -211,27 +239,27 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
for (int i = 0; i < n_logits; i++) {
|
||||
fprintf(f, "%d: %.6f\n", i, logits[i]);
|
||||
for (int i = 0; i < data_size; i++) {
|
||||
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
|
||||
}
|
||||
fclose(f);
|
||||
|
||||
if (!embedding_mode) {
|
||||
printf("First 10 logits: ");
|
||||
for (int i = 0; i < 10 && i < n_logits; i++) {
|
||||
printf("%.6f ", logits[i]);
|
||||
for (int i = 0; i < 10 && i < data_size; i++) {
|
||||
printf("%.6f ", data_ptr[i]);
|
||||
}
|
||||
printf("\n");
|
||||
|
||||
printf("Last 10 logits: ");
|
||||
for (int i = n_logits - 10; i < n_logits; i++) {
|
||||
if (i >= 0) printf("%.6f ", logits[i]);
|
||||
for (int i = data_size - 10; i < data_size; i++) {
|
||||
if (i >= 0) printf("%.6f ", data_ptr[i]);
|
||||
}
|
||||
printf("\n\n");
|
||||
}
|
||||
|
||||
printf("Logits saved to %s\n", bin_filename);
|
||||
printf("Logits saved to %s\n", txt_filename);
|
||||
printf("Data saved to %s\n", bin_filename);
|
||||
printf("Data saved to %s\n", txt_filename);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
|
|
|||
|
|
@ -4,3 +4,4 @@ torchvision
|
|||
transformers
|
||||
huggingface-hub
|
||||
accelerate
|
||||
sentence-transformers
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ if model_path is None:
|
|||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
|
|
@ -148,8 +148,8 @@ print("BOS token id: ", config.bos_token_id)
|
|||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
|
|
@ -171,7 +171,7 @@ if unreleased_model_name:
|
|||
exit(1)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload"
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
|
|
|||
|
|
@ -2,6 +2,21 @@
|
|||
|
||||
set -e
|
||||
|
||||
# Parse command line arguments
|
||||
SENTENCE_TRANSFORMERS=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
-st|--sentence-transformers)
|
||||
SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
|
||||
OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
|
||||
TYPE="${OUTTYPE:-f16}"
|
||||
|
|
@ -15,7 +30,8 @@ echo "Converted model path:: ${CONVERTED_MODEL}"
|
|||
python ../../convert_hf_to_gguf.py --verbose \
|
||||
${EMBEDDING_MODEL_PATH} \
|
||||
--outfile ${CONVERTED_MODEL} \
|
||||
--outtype ${TYPE}
|
||||
--outtype ${TYPE} \
|
||||
${SENTENCE_TRANSFORMERS}
|
||||
|
||||
echo ""
|
||||
echo "The environment variable CONVERTED_EMBEDDING MODEL can be set to this path using:"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ set -e
|
|||
# Parse command line arguments
|
||||
CONVERTED_MODEL=""
|
||||
PROMPTS_FILE=""
|
||||
USE_POOLING=""
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
|
|
@ -12,6 +13,10 @@ while [[ $# -gt 0 ]]; do
|
|||
PROMPTS_FILE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--pooling)
|
||||
USE_POOLING="1"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
if [ -z "$CONVERTED_MODEL" ]; then
|
||||
CONVERTED_MODEL="$1"
|
||||
|
|
@ -47,4 +52,8 @@ echo $CONVERTED_MODEL
|
|||
|
||||
cmake --build ../../build --target llama-logits -j8
|
||||
# TODO: update logits.cpp to accept a --file/-f option for the prompt
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
|
||||
if [ -n "$USE_POOLING" ]; then
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
|
||||
else
|
||||
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
|
||||
fi
|
||||
|
|
|
|||
|
|
@ -14,6 +14,8 @@ unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
|
|||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
parser.add_argument('--prompts-file', '-p', help='Path to file containing prompts (one per line)')
|
||||
parser.add_argument('--use-sentence-transformers', action='store_true',
|
||||
help='Use SentenceTransformer to apply all numbered layers (01_Pooling, 02_Dense, 03_Dense, 04_Normalize)')
|
||||
args = parser.parse_args()
|
||||
|
||||
def read_prompt_from_file(file_path):
|
||||
|
|
@ -31,41 +33,52 @@ model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
|
|||
if model_path is None:
|
||||
parser.error("Model path must be specified either via --model-path argument or EMBEDDING_MODEL_PATH environment variable")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
# Determine if we should use SentenceTransformer
|
||||
use_sentence_transformers = args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# This can be used to override the sliding window size for manual testing. This
|
||||
# can be useful to verify the sliding window attention mask in the original model
|
||||
# and compare it with the converted .gguf model.
|
||||
if hasattr(config, 'sliding_window'):
|
||||
original_sliding_window = config.sliding_window
|
||||
#original_sliding_window = 6
|
||||
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
|
||||
|
||||
print(f"Using unreleased model: {unreleased_model_name}")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
class_name = f"{unreleased_model_name}Model"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(model_path, config=config)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
if use_sentence_transformers:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
print("Using SentenceTransformer to apply all numbered layers")
|
||||
model = SentenceTransformer(model_path)
|
||||
tokenizer = model.tokenizer
|
||||
config = model[0].auto_model.config # type: ignore
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, config=config)
|
||||
print(f"Model class: {type(model)}")
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
# This can be used to override the sliding window size for manual testing. This
|
||||
# can be useful to verify the sliding window attention mask in the original model
|
||||
# and compare it with the converted .gguf model.
|
||||
if hasattr(config, 'sliding_window'):
|
||||
original_sliding_window = config.sliding_window
|
||||
#original_sliding_window = 6
|
||||
print(f"Modified sliding window: {original_sliding_window} -> {config.sliding_window}")
|
||||
|
||||
print(f"Using unreleased model: {unreleased_model_name}")
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
|
||||
class_name = f"{unreleased_model_name}Model"
|
||||
print(f"Importing unreleased model module: {unreleased_module_path}")
|
||||
|
||||
try:
|
||||
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
|
||||
model = model_class.from_pretrained(model_path, config=config)
|
||||
except (ImportError, AttributeError) as e:
|
||||
print(f"Failed to import or load model: {e}")
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModel.from_pretrained(model_path, config=config)
|
||||
print(f"Model class: {type(model)}")
|
||||
print(f"Model file: {type(model).__module__}")
|
||||
|
||||
# Verify the model is using the correct sliding window
|
||||
if hasattr(model.config, 'sliding_window'):
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}")
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
if not use_sentence_transformers:
|
||||
if hasattr(model.config, 'sliding_window'): # type: ignore
|
||||
print(f"Model's sliding_window: {model.config.sliding_window}") # type: ignore
|
||||
else:
|
||||
print("Model config does not have sliding_window attribute")
|
||||
|
||||
model_name = os.path.basename(model_path)
|
||||
|
||||
|
|
@ -75,34 +88,56 @@ if args.prompts_file:
|
|||
else:
|
||||
texts = ["Hello world today"]
|
||||
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**encoded)
|
||||
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
|
||||
if use_sentence_transformers:
|
||||
embeddings = model.encode(texts, convert_to_numpy=True)
|
||||
all_embeddings = embeddings # Shape: [batch_size, hidden_size]
|
||||
|
||||
# Extract embeddings for each token (matching LLAMA_POOLING_TYPE_NONE behavior)
|
||||
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print(f"Hidden states shape: {hidden_states.shape}")
|
||||
print(f"All embeddings shape: {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}") # type: ignore
|
||||
else:
|
||||
# Standard approach: use base model output only
|
||||
encoded = tokenizer(
|
||||
texts,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Print embeddings exactly like embedding.cpp does for LLAMA_POOLING_TYPE_NONE
|
||||
n_embd = all_embeddings.shape[1]
|
||||
n_embd_count = all_embeddings.shape[0]
|
||||
tokens = encoded['input_ids'][0]
|
||||
token_strings = tokenizer.convert_ids_to_tokens(tokens)
|
||||
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
|
||||
print(f"{token_id:6d} -> '{token_str}'")
|
||||
|
||||
print() # Empty line to match C++ output
|
||||
outputs = model(**encoded)
|
||||
hidden_states = outputs.last_hidden_state # Shape: [batch_size, seq_len, hidden_size]
|
||||
|
||||
all_embeddings = hidden_states[0].cpu().numpy() # Shape: [seq_len, hidden_size]
|
||||
|
||||
print(f"Hidden states shape: {hidden_states.shape}")
|
||||
print(f"All embeddings shape: {all_embeddings.shape}")
|
||||
print(f"Embedding dimension: {all_embeddings.shape[1]}")
|
||||
|
||||
if len(all_embeddings.shape) == 1:
|
||||
n_embd = all_embeddings.shape[0] # type: ignore
|
||||
n_embd_count = 1
|
||||
all_embeddings = all_embeddings.reshape(1, -1)
|
||||
else:
|
||||
n_embd = all_embeddings.shape[1] # type: ignore
|
||||
n_embd_count = all_embeddings.shape[0] # type: ignore
|
||||
|
||||
print()
|
||||
|
||||
for j in range(n_embd_count):
|
||||
embedding = all_embeddings[j]
|
||||
|
|
@ -120,29 +155,23 @@ with torch.no_grad():
|
|||
|
||||
print() # New line
|
||||
|
||||
print() # Final empty line to match C++ output
|
||||
print()
|
||||
|
||||
data_dir = Path("data")
|
||||
data_dir.mkdir(exist_ok=True)
|
||||
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
|
||||
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
|
||||
|
||||
# Save all embeddings flattened (matching what embedding.cpp would save if it did)
|
||||
flattened_embeddings = all_embeddings.flatten()
|
||||
flattened_embeddings.astype(np.float32).tofile(bin_filename)
|
||||
|
||||
with open(txt_filename, "w") as f:
|
||||
f.write(f"# Model class: {model_name}\n")
|
||||
f.write(f"# Tokens: {token_strings}\n")
|
||||
f.write(f"# Shape: {all_embeddings.shape}\n")
|
||||
f.write(f"# n_embd_count: {n_embd_count}, n_embd: {n_embd}\n\n")
|
||||
|
||||
idx = 0
|
||||
for j in range(n_embd_count):
|
||||
f.write(f"# Token {j} ({token_strings[j]}):\n")
|
||||
for i, value in enumerate(all_embeddings[j]):
|
||||
f.write(f"{j}_{i}: {value:.6f}\n")
|
||||
f.write("\n")
|
||||
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} tokens × {n_embd} dimensions)")
|
||||
for value in all_embeddings[j]:
|
||||
f.write(f"{idx}: {value:.6f}\n")
|
||||
idx += 1
|
||||
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
|
||||
print("")
|
||||
print(f"Saved bin embeddings to: {bin_filename}")
|
||||
print(f"Saved txt embeddings to: {txt_filename}")
|
||||
|
|
|
|||
|
|
@ -35,7 +35,11 @@ def cosine_similarity(a, b=None):
|
|||
|
||||
def load_embeddings_from_file(filename, n_tokens, n_embd):
|
||||
embeddings = np.fromfile(filename, dtype=np.float32)
|
||||
return embeddings.reshape(n_tokens, n_embd)
|
||||
# Check if this is pooled (single embedding) or per-token embeddings
|
||||
if len(embeddings) == n_embd:
|
||||
return embeddings.reshape(1, n_embd)
|
||||
else:
|
||||
return embeddings.reshape(n_tokens, n_embd)
|
||||
|
||||
def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
|
||||
np.set_printoptions(suppress=True, precision=6)
|
||||
|
|
@ -48,58 +52,83 @@ def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt):
|
|||
print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}")
|
||||
|
||||
n_tokens = len(tokens)
|
||||
is_pooled = python_emb.shape[0] == 1
|
||||
|
||||
# 1. Direct embedding comparison
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
# Check the distance of each token embedding from the origin and compare
|
||||
# if the vectors are on the same "sphere". This does not tell us about
|
||||
# direction (meaning of the token embedding), just magnitude.
|
||||
for i in range(n_tokens):
|
||||
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
|
||||
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
|
||||
if is_pooled:
|
||||
print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]")
|
||||
|
||||
# 1. Direct embedding comparison for pooled embeddings
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
py_mag = np.linalg.norm(python_emb[0])
|
||||
cpp_mag = np.linalg.norm(cpp_emb[0])
|
||||
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
|
||||
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
|
||||
# 2. Cosine similarity between tokens within each model
|
||||
# Here we check the direction of token embeddings to see if they have the
|
||||
# same meaning (similarity). This is done by calculating cosine similarity
|
||||
# of a pair of token embeddings within each model.
|
||||
print(f"\n2. Within-Model Token Similarities:")
|
||||
print(" Python model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
# 2. Cross-model similarity for pooled embeddings
|
||||
print(f"\n2. Cross-Model Pooled Embedding Similarity:")
|
||||
sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0]
|
||||
print(f" Cosine similarity: {sim:.6f}")
|
||||
|
||||
print(" llama.cpp model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
return {
|
||||
'cross_model_similarities': [sim],
|
||||
'similarity_matrix_diff': np.array([[0.0]]),
|
||||
'max_diff': 0.0,
|
||||
'mean_diff': 0.0,
|
||||
'rms_diff': 0.0
|
||||
}
|
||||
else:
|
||||
# Original per-token comparison logic
|
||||
# 1. Direct embedding comparison
|
||||
print(f"\n1. Raw Embedding Magnitude Comparison:")
|
||||
# Check the distance of each token embedding from the origin and compare
|
||||
# if the vectors are on the same "sphere". This does not tell us about
|
||||
# direction (meaning of the token embedding), just magnitude.
|
||||
for i in range(n_tokens):
|
||||
py_mag = np.linalg.norm(python_emb[i]) # calculate standard euclidean norm for Python embeddings
|
||||
cpp_mag = np.linalg.norm(cpp_emb[i]) # calculate standard euclidean norm for llama.cpp embeddings
|
||||
ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
|
||||
print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
|
||||
|
||||
# 3. Cross-model similarity (same token position)
|
||||
print(f"\n3. Cross-Model Same-Token Similarities:")
|
||||
for i in range(n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
|
||||
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
|
||||
# 2. Cosine similarity between tokens within each model
|
||||
# Here we check the direction of token embeddings to see if they have the
|
||||
# same meaning (similarity). This is done by calculating cosine similarity
|
||||
# of a pair of token embeddings within each model.
|
||||
print(f"\n2. Within-Model Token Similarities:")
|
||||
print(" Python model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
|
||||
# 4. Similarity matrix comparison
|
||||
print(f"\n4. Similarity Matrix Differences:")
|
||||
py_sim_matrix = cosine_similarity(python_emb)
|
||||
cpp_sim_matrix = cosine_similarity(cpp_emb)
|
||||
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
|
||||
print(" llama.cpp model:")
|
||||
for i in range(n_tokens):
|
||||
for j in range(i+1, n_tokens):
|
||||
sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
|
||||
print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
|
||||
|
||||
print(f" Max difference: {np.max(diff_matrix):.4f}")
|
||||
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
|
||||
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
|
||||
# 3. Cross-model similarity (same token position)
|
||||
print(f"\n3. Cross-Model Same-Token Similarities:")
|
||||
for i in range(n_tokens):
|
||||
sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
|
||||
print(f" Token {i} ({tokens[i]}): {sim:.4f}")
|
||||
|
||||
return {
|
||||
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
|
||||
'similarity_matrix_diff': diff_matrix,
|
||||
'max_diff': np.max(diff_matrix),
|
||||
'mean_diff': np.mean(diff_matrix),
|
||||
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
|
||||
}
|
||||
# 4. Similarity matrix comparison
|
||||
print(f"\n4. Similarity Matrix Differences:")
|
||||
py_sim_matrix = cosine_similarity(python_emb)
|
||||
cpp_sim_matrix = cosine_similarity(cpp_emb)
|
||||
diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
|
||||
|
||||
print(f" Max difference: {np.max(diff_matrix):.4f}")
|
||||
print(f" Mean difference: {np.mean(diff_matrix):.4f}")
|
||||
print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
|
||||
|
||||
return {
|
||||
'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
|
||||
'similarity_matrix_diff': diff_matrix,
|
||||
'max_diff': np.max(diff_matrix),
|
||||
'mean_diff': np.mean(diff_matrix),
|
||||
'rms_diff': np.sqrt(np.mean(diff_matrix**2))
|
||||
}
|
||||
|
||||
def read_prompt_from_file(file_path):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -251,6 +251,8 @@ option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adr
|
|||
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"gmml: OpenCL API version to target")
|
||||
|
||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
|
||||
|
||||
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
|
|||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||
size_t n_threads, size_t n_devices,
|
||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
|
||||
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||
|
|
|
|||
|
|
@ -577,6 +577,10 @@ extern "C" {
|
|||
GGML_UNARY_OP_EXP,
|
||||
GGML_UNARY_OP_GELU_ERF,
|
||||
GGML_UNARY_OP_XIELU,
|
||||
GGML_UNARY_OP_FLOOR,
|
||||
GGML_UNARY_OP_CEIL,
|
||||
GGML_UNARY_OP_ROUND,
|
||||
GGML_UNARY_OP_TRUNC,
|
||||
|
||||
GGML_UNARY_OP_COUNT,
|
||||
};
|
||||
|
|
@ -1151,6 +1155,46 @@ extern "C" {
|
|||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_floor(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_floor_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ceil(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_ceil_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_round(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_round_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
/**
|
||||
* Truncates the fractional part of each element in the tensor (towards zero).
|
||||
* For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0
|
||||
* Similar to std::trunc in C/C++.
|
||||
*/
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_trunc(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_trunc_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
|
||||
|
||||
// xIELU activation function
|
||||
// x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
|
||||
// where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
|
||||
|
|
|
|||
|
|
@ -145,6 +145,9 @@ endif()
|
|||
# which was introduced in POSIX.1-2008, forcing us to go higher
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
|
||||
add_compile_definitions(_XOPEN_SOURCE=700)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES "AIX")
|
||||
# Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default,
|
||||
# in order to define _SC_PHYS_PAGES.
|
||||
else()
|
||||
add_compile_definitions(_XOPEN_SOURCE=600)
|
||||
endif()
|
||||
|
|
@ -304,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name)
|
|||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||
|
|
@ -368,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS)
|
|||
else()
|
||||
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
|
@ -387,6 +402,7 @@ ggml_add_backend(Vulkan)
|
|||
ggml_add_backend(WebGPU)
|
||||
ggml_add_backend(zDNN)
|
||||
ggml_add_backend(OpenCL)
|
||||
ggml_add_backend(Hexagon)
|
||||
|
||||
foreach (target ggml-base ggml)
|
||||
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
||||
|
|
|
|||
|
|
@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
|||
}
|
||||
|
||||
if (best_fit_block == -1) {
|
||||
// no suitable block found, try the last block (this will grow a chunks size)
|
||||
// no suitable block found, try the last block (this may grow a chunk's size)
|
||||
int64_t best_reuse = INT64_MIN;
|
||||
for (int c = 0; c < alloc->n_chunks; ++c) {
|
||||
struct tallocr_chunk * chunk = alloc->chunks[c];
|
||||
if (chunk->n_free_blocks > 0) {
|
||||
struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1];
|
||||
max_avail = MAX(max_avail, block->size);
|
||||
if (block->size >= size) {
|
||||
int64_t reuse_factor = chunk->max_size - block->offset - size;
|
||||
// reuse_factor < 0 : amount of extra memory that needs to be allocated
|
||||
// reuse_factor = 0 : allocated free space exactly matches tensor size
|
||||
// reuse_factor > 0 : superfluous memory that will remain unused
|
||||
bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse;
|
||||
bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse;
|
||||
if (block->size >= size && (better_reuse || better_fit)) {
|
||||
best_fit_chunk = c;
|
||||
best_fit_block = chunk->n_free_blocks - 1;
|
||||
break;
|
||||
best_reuse = reuse_factor;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al
|
|||
#ifdef GGML_ALLOCATOR_DEBUG
|
||||
add_allocated_tensor(alloc, addr, tensor);
|
||||
size_t cur_max = addr.offset + size;
|
||||
if (cur_max > alloc->max_size[addr.chunk]) {
|
||||
if (cur_max > chunk->max_size) {
|
||||
// sort allocated_tensors by chunk/offset
|
||||
for (int i = 0; i < 1024; i++) {
|
||||
for (int j = i + 1; j < 1024; j++) {
|
||||
|
|
@ -598,6 +605,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
|
||||
}
|
||||
|
||||
// free the extra space at the end if the new tensor is smaller
|
||||
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||
|
||||
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
|
||||
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
|
||||
|
||||
GGML_ASSERT(parent_size >= node_size);
|
||||
|
||||
if (parent_size > node_size) {
|
||||
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
|
||||
struct buffer_address p_addr = p_hn->addr;
|
||||
p_addr.offset += node_size;
|
||||
size_t extra_size = parent_size - node_size;
|
||||
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
|
||||
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
|
||||
GGML_ASSERT(buffer_id >= 0);
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
|
|
@ -643,6 +670,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
hn->addr = p_hn->addr;
|
||||
p_hn->allocated = false; // avoid freeing the parent
|
||||
view_src_hn->allocated = false;
|
||||
ggml_gallocr_free_extra_space(galloc, node, view_src);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -650,6 +678,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
hn->buffer_id = p_hn->buffer_id;
|
||||
hn->addr = p_hn->addr;
|
||||
p_hn->allocated = false; // avoid freeing the parent
|
||||
ggml_gallocr_free_extra_space(galloc, node, parent);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@
|
|||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_HEXAGON
|
||||
#include "ggml-hexagon.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_BLAS
|
||||
#include "ggml-blas.h"
|
||||
#endif
|
||||
|
|
@ -199,6 +203,9 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_HEXAGON
|
||||
register_backend(ggml_backend_hexagon_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_CANN
|
||||
register_backend(ggml_backend_cann_reg());
|
||||
#endif
|
||||
|
|
@ -598,6 +605,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
ggml_backend_load_best("cpu", silent, dir_path);
|
||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||
|
|
|
|||
|
|
@ -51,28 +51,31 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
|||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
|
||||
size_t* nb, int64_t dims, aclFormat format,
|
||||
size_t offset) {
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
if (ne == nullptr) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
acl_ne[i] = tensor->ne[i];
|
||||
acl_ne[i] = tensor->ne[i];
|
||||
// The step size of acl is in elements.
|
||||
acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);
|
||||
}
|
||||
} else {
|
||||
// With bcast
|
||||
for (int i = 0; i < dims; i++) {
|
||||
acl_ne[i] = ne[i];
|
||||
acl_ne[i] = ne[i];
|
||||
acl_stride[i] = nb[i] / ggml_element_size(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
|
||||
int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
|
||||
int64_t acl_storage_len = 1;
|
||||
for (int i = 0; i < final_dims; i++) {
|
||||
acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
|
||||
|
|
@ -84,15 +87,13 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
|
|||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor* acl_tensor = aclCreateTensor(
|
||||
acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1,
|
||||
tensor->data);
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor;
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
|
||||
return true;
|
||||
|
|
@ -101,15 +102,16 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
|
|||
return false;
|
||||
}
|
||||
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
||||
const ggml_tensor* src1,
|
||||
int64_t* bcast_src0_ne,
|
||||
int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
|
||||
size_t* bcast_src1_nb) {
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
int64_t * bcast_src0_ne,
|
||||
int64_t * bcast_src1_ne,
|
||||
size_t * bcast_src0_nb,
|
||||
size_t * bcast_src1_nb) {
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0));
|
||||
int bcast_dim_cnt = 0;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
int64_t nr = src0->ne[i] / src1->ne[i];
|
||||
int64_t nr = src0->ne[i] / src1->ne[i];
|
||||
bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;
|
||||
bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];
|
||||
bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];
|
||||
|
|
@ -119,21 +121,26 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
|||
// Need to add an extra dim.
|
||||
bcast_src0_ne[bcast_dim_cnt] = nr;
|
||||
bcast_src1_ne[bcast_dim_cnt] = 1;
|
||||
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
|
||||
bcast_src0_ne[bcast_dim_cnt - 1];
|
||||
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
|
||||
bcast_src1_ne[bcast_dim_cnt - 1];
|
||||
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * bcast_src0_ne[bcast_dim_cnt - 1];
|
||||
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * bcast_src1_ne[bcast_dim_cnt - 1];
|
||||
bcast_dim_cnt++;
|
||||
}
|
||||
}
|
||||
return bcast_dim_cnt;
|
||||
}
|
||||
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(
|
||||
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
|
||||
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
|
||||
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
|
||||
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
|
||||
const int64_t * weight_ne,
|
||||
const int64_t * dst_ne,
|
||||
const size_t * input_nb,
|
||||
const size_t * weight_nb,
|
||||
const size_t * dst_nb,
|
||||
int64_t * bcast_input_ne,
|
||||
int64_t * bcast_weight_ne,
|
||||
int64_t * bcast_dst_ne,
|
||||
size_t * bcast_input_nb,
|
||||
size_t * bcast_weight_nb,
|
||||
size_t * bcast_dst_nb) {
|
||||
// input and dst should have the same shape, except for the first two dims.
|
||||
GGML_ASSERT(input_ne[2] == dst_ne[2]);
|
||||
GGML_ASSERT(input_ne[3] == dst_ne[3]);
|
||||
|
|
@ -148,34 +155,30 @@ int64_t ggml_cann_get_mulmat_bcast_shape(
|
|||
// Do not use bcast in the first two dimensions because we only support
|
||||
// the bcast batch dimension. Just copy them.
|
||||
if (i < 2 || nr == 1) {
|
||||
bcast_input_ne[bcast_dim_cnt] = input_ne[i];
|
||||
bcast_input_ne[bcast_dim_cnt] = input_ne[i];
|
||||
bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
|
||||
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
|
||||
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
|
||||
|
||||
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
|
||||
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
|
||||
bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
|
||||
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
|
||||
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
|
||||
bcast_dim_cnt++;
|
||||
} else {
|
||||
// Need to add an extra dim.
|
||||
bcast_input_ne[bcast_dim_cnt] = nr;
|
||||
bcast_dst_ne[bcast_dim_cnt] = nr;
|
||||
bcast_input_ne[bcast_dim_cnt] = nr;
|
||||
bcast_dst_ne[bcast_dim_cnt] = nr;
|
||||
bcast_weight_ne[bcast_dim_cnt] = 1;
|
||||
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
|
||||
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
|
||||
bcast_input_nb[bcast_dim_cnt] = input_nb[i];
|
||||
bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
|
||||
bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
|
||||
bcast_dim_cnt++;
|
||||
|
||||
bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
|
||||
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
|
||||
bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
|
||||
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
|
||||
bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
|
||||
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
|
||||
bcast_input_ne[bcast_dim_cnt - 1];
|
||||
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
|
||||
bcast_dst_ne[bcast_dim_cnt - 1];
|
||||
bcast_weight_nb[bcast_dim_cnt] =
|
||||
bcast_weight_nb[bcast_dim_cnt - 1] *
|
||||
bcast_weight_ne[bcast_dim_cnt - 1];
|
||||
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * bcast_input_ne[bcast_dim_cnt - 1];
|
||||
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * bcast_dst_ne[bcast_dim_cnt - 1];
|
||||
bcast_weight_nb[bcast_dim_cnt] = bcast_weight_nb[bcast_dim_cnt - 1] * bcast_weight_ne[bcast_dim_cnt - 1];
|
||||
bcast_dim_cnt++;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,10 +62,12 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
|
|||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
|
||||
size_t* nb = nullptr, int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
/**
|
||||
* @brief Template for creating an ACL tensor from provided parameters. typename TYPE
|
||||
|
|
@ -87,12 +89,15 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
|
|||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template<typename TYPE>
|
||||
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
||||
TYPE type_size, int64_t* ne, TYPE* nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
template <typename TYPE>
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
int64_t tmp_ne[GGML_MAX_DIMS * 2];
|
||||
int64_t tmp_stride[GGML_MAX_DIMS * 2];
|
||||
|
||||
|
|
@ -109,9 +114,8 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
|||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor* acl_tensor =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
|
||||
format, &acl_storage_len, 1, data_ptr);
|
||||
aclTensor * acl_tensor =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor;
|
||||
}
|
||||
|
|
@ -132,7 +136,7 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
|||
* to 1. If such a dimension is found, broadcasting is required to align t1
|
||||
* with t0 for element-wise operations.
|
||||
*/
|
||||
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1);
|
||||
|
||||
/**
|
||||
* @brief Computes broadcast shapes and strides for two ggml_tensors.
|
||||
|
|
@ -187,19 +191,21 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
|
|||
* dim1 in a inserted dim, should add nb for dim1,
|
||||
* and all other nb moves to next in order.
|
||||
*/
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
size_t* bcast_nb_src0, size_t* bcast_nb_src1);
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
const ggml_tensor * src1,
int64_t * bcast_ne_src0,
int64_t * bcast_ne_src1,
size_t * bcast_nb_src0,
size_t * bcast_nb_src1);

// Bcast macro to avoid duplicate code.
#define BCAST_SHAPE(src0, src1) \
int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_bcast_shape( \
src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
bcast_##src1##_nb);
#define BCAST_SHAPE(src0, src1) \
int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2]; \
int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \
bcast_##src0##_nb, bcast_##src1##_nb);

#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
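A minimal usage sketch for these macros (the helper name below is mine, and it mirrors what the bcast_shape() helper declared earlier in this header is documented to do): an element-wise op expands BCAST_SHAPE once, then feeds BCAST_PARAM into ggml_cann_create_tensor, which accepts the (ne, nb, dims) triple the macro declares.

// Sketch only: assumes src0/src1/dst follow ggml broadcasting rules
// (src1 broadcastable to src0) and that the caller releases the ACL tensors.
static void example_bcast_tensors(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
                                  aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst) {
    BCAST_SHAPE(src0, src1)  // declares bcast_src0_ne/nb, bcast_src1_ne/nb and bcast_dims
    *acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
    *acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
    // dst reuses src0's broadcast view in this sketch.
    *acl_dst  = ggml_cann_create_tensor(dst,  BCAST_PARAM(src0));
}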
@ -233,26 +239,31 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr
|
|||
* before cast dim.
|
||||
* @sa ggml_cann_get_bcast_shape
|
||||
*/
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(
|
||||
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
|
||||
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
|
||||
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
|
||||
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
|
||||
const int64_t * weight_ne,
|
||||
const int64_t * dst_ne,
|
||||
const size_t * input_nb,
|
||||
const size_t * weight_nb,
|
||||
const size_t * dst_nb,
|
||||
int64_t * bcast_input_ne,
|
||||
int64_t * bcast_weight_ne,
|
||||
int64_t * bcast_dst_ne,
|
||||
size_t * bcast_input_nb,
|
||||
size_t * bcast_weight_nb,
|
||||
size_t * bcast_dst_nb);
|
||||
|
||||
// Bcast macro to avoid duplicate code.
|
||||
#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
|
||||
int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
|
||||
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
|
||||
bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \
|
||||
bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
|
||||
#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
|
||||
int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##input##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
|
||||
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \
|
||||
bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
|
||||
|
||||
#define BCAST_MUL_MAT_PARAM(tensor) \
|
||||
bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
|
||||
#define BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
|
||||
|
||||
#endif // CANN_ACL_TENSOR_H
|
||||
|
|
|
|||
File diff suppressed because it is too large
@ -62,7 +62,7 @@
|
|||
* @param dst The ggml tensor representing the destination, which op is
|
||||
* GGML_OP_REPEAT and specifies the desired dimensions.
|
||||
*/
|
||||
void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the Leaky ReLU activation function to a tensor using the CANN
|
||||
|
|
@ -82,7 +82,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result of the Leaky ReLU
|
||||
* activation is stored, which op is `GGML_OP_LEAKY_RELU`
|
||||
*/
|
||||
void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Concatenates multiple tensors along a specified dimension using the
|
||||
|
|
@ -97,7 +97,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @attention tensorList length should be 2 and the dimension using for concat
|
||||
* default to 1.
|
||||
*/
|
||||
void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Generates a sequence of evenly spaced values within a specified
|
||||
|
|
@ -113,7 +113,7 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* `start`, 'stop' and 'step' are in dst->op_params and dst->op is
|
||||
* `GGML_OP_ARANGE`.
|
||||
*/
|
||||
void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a clamp operation to the elements of a ggml tensor using the
|
||||
|
|
@ -131,7 +131,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the clamped values will be stored.
|
||||
* dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params.
|
||||
*/
|
||||
void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Scales the elements of a ggml tensor by a constant factor using the
|
||||
|
|
@ -148,7 +148,7 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the scaled values will be stored.
|
||||
* dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
|
||||
*/
|
||||
void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Sorts the elements of a ggml tensor and returns the indices that
|
||||
|
|
@ -163,7 +163,7 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the sorted indices will be stored.
|
||||
* dst->op is `GGML_OP_ARGSORT`.
|
||||
*/
|
||||
void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Layer Normalization for a ggml tensor using the CANN
|
||||
|
|
@ -185,7 +185,7 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the normalized values will be stored.
|
||||
* @attention `Var` defaults to dst->ne[0].
|
||||
*/
|
||||
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Group Normalization for a ggml tensor using the CANN
|
||||
|
|
@ -209,7 +209,7 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @attention eps defaults to 1e-6f.
|
||||
*/
|
||||
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the accumulation of tensors using the CANN backend.
|
||||
|
|
@ -228,7 +228,7 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the accumulated values will be stored.
|
||||
* `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`.
|
||||
*/
|
||||
void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the sum of elements along the last dimension of a ggml tensor
|
||||
|
|
@ -244,7 +244,7 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @attention `reduce_dims` defaults to 3, which means the last dimension.
|
||||
*/
|
||||
void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the sum of elements in a ggml tensor.
|
||||
|
|
@ -258,7 +258,7 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
*/
|
||||
|
||||
void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
|
||||
|
|
@ -274,8 +274,7 @@ void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the upsampled values will be stored.
|
||||
* dst->op is `GGML_OP_UPSCALE`.
|
||||
*/
|
||||
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
|
||||
ggml_tensor* dst);
|
||||
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Pads a ggml tensor to match the dimensions of the destination tensor
|
||||
|
|
@ -290,7 +289,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
|
|||
* @param dst The destination tensor, which specifies the target dimensions for
|
||||
* padding. dst->op is `GGML_OP_PAD`.
|
||||
*/
|
||||
void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Executes a 2D pooling operation on a ggml tensor using the CANN
|
||||
|
|
@ -307,7 +306,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor on which the pooling operation is to be
|
||||
* performed. dst->op is `GGML_OP_POOL_2D`.
|
||||
*/
|
||||
void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Duplicates a ggml tensor using the CANN backend.
|
||||
|
|
@ -326,7 +325,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* different shape and dst is no-contiguous.
|
||||
* @note: This func need to simplify.
|
||||
*/
|
||||
void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor
|
||||
|
|
@ -348,7 +347,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the normalized values will be stored.
|
||||
* dst->op is `GGML_OP_RMS_NORM`.
|
||||
*/
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a diagonal mask to the tensor with a specified value.
|
||||
|
|
@ -363,7 +362,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* `GGML_OP_DIAG_MASK`
|
||||
* @param value The value to use for masking.
|
||||
*/
|
||||
void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
|
||||
void ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value);
|
||||
|
||||
/**
|
||||
* @brief Performs an image-to-column transformation on the input tensor.
|
||||
|
|
@ -378,7 +377,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float
|
|||
* @param dst The destination tensor that stores the result of the operation.
|
||||
* dst->op is `GGML_OP_IM2COL`.
|
||||
*/
|
||||
void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes time step embeddings using sine and cosine functions.
|
||||
|
|
@ -392,10 +391,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result of the embedding operation
|
||||
* will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
|
||||
*/
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// @see ggml_cann_dup.
|
||||
void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the softmax activation with optional masking.
|
||||
|
|
@ -417,7 +416,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored. dst->op is
|
||||
* `GGML_OP_SOFTMAX`.
|
||||
*/
|
||||
void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Extracts specific rows from a tensor based on indices.
|
||||
|
|
@ -429,7 +428,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param ctx The backend CANN context for executing operations.
|
||||
* @param dst The destination tensor where the extracted rows will be stored.
|
||||
*/
|
||||
void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Writes specific rows into a tensor at positions specified by indices.
|
||||
|
|
@ -441,7 +440,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param ctx The backend CANN context for executing operations.
|
||||
* @param dst The destination tensor where the specified rows will be updated.
|
||||
*/
|
||||
void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Executes matrix multiplication for the given tensor.
|
||||
|
|
@ -454,7 +453,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor for storing the result of the matrix
|
||||
* multiplication. dst->op is `GGML_OP_MUL_MAT`.
|
||||
*/
|
||||
void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
|
||||
|
|
@ -477,7 +476,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @note The function currently does not support cases where the freq_scale is
|
||||
* not equal 1.
|
||||
*/
|
||||
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the index of the maximum value along the specified dimension
|
||||
|
|
@ -492,7 +491,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the indices of the maximum values will
|
||||
* be stored. dst->op is `GGML_OP_ARGMAX`.
|
||||
*/
|
||||
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Adds two tensors element-wise and stores the result in a destination
|
||||
|
|
@ -509,8 +508,10 @@ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param acl_src1 The second source tensor.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
||||
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_add(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src0,
|
||||
aclTensor * acl_src1,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Sub two tensors element-wise and stores the result in a destination
|
||||
|
|
@ -527,8 +528,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
|||
* @param acl_src1 The second source tensor.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
||||
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_sub(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src0,
|
||||
aclTensor * acl_src1,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Performs element-wise multiplication of two tensors and stores the
|
||||
|
|
@ -546,8 +549,10 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
|||
* @param acl_other The second tensor for element-wise multiplication.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_mul(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src,
|
||||
aclTensor * acl_other,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Matrix division, optionally in-place.
|
||||
|
|
@ -567,8 +572,10 @@ void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param inplace Flag indicating whether to perform the operation in-place on
|
||||
* `acl_src`.
|
||||
*/
|
||||
void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_div(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src,
|
||||
aclTensor * acl_other,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Applies element-wise cosine function to the elements of a tensor.
|
||||
|
|
@ -584,8 +591,7 @@ void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_dst The destination tensor where the cosine results will be
|
||||
* stored.
|
||||
*/
|
||||
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst);
|
||||
void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Applies element-wise sine function to the elements of a tensor.
|
||||
|
|
@ -602,8 +608,7 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_src The source tensor on which the sine function will be applied.
|
||||
* @param acl_dst The destination tensor where the sine results will be stored.
|
||||
*/
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst);
|
||||
void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
|
||||
|
|
@ -621,8 +626,12 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
|
||||
aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
|
|
@ -637,7 +646,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
|
|||
* @param dst The destination tensor where the transposed convolution result
|
||||
* will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.
|
||||
*/
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor
|
||||
|
|
@ -662,7 +671,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||
* @param dst The destination tensor where the ELU-activated result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_ELU`.
|
||||
*/
|
||||
void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the mean of a ggml tensor element-wise using the CANN backend.
|
||||
|
|
@ -677,7 +686,7 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the mean result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_MEAN`.
|
||||
*/
|
||||
void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies 1D reflect padding to a ggml tensor using the CANN backend.
|
||||
|
|
@ -692,7 +701,7 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the padded result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.
|
||||
*/
|
||||
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Counts the number of equal elements in two ggml tensors using the CANN backend.
|
||||
|
|
@ -708,7 +717,7 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_COUNT_EQUAL`.
|
||||
*/
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the Step activation function to a ggml tensor using the CANN backend.
|
||||
|
|
@ -723,7 +732,7 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_STEP`.
|
||||
*/
|
||||
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Performs the Flash Attention extended operator using the CANN backend.
|
||||
|
|
@ -738,59 +747,46 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
|
||||
*/
|
||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/*
|
||||
* @brief A generic wrapper for ACL resources with custom deleter support.
|
||||
*/
|
||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
|
||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
|
||||
|
||||
/**
|
||||
* @brief Trait structure used to define how to destroy a given ACL resource type.
|
||||
*
|
||||
* @tparam T ACL resource type.
|
||||
*/
|
||||
template<typename T>
|
||||
struct acl_resource_traits;
|
||||
template <typename T> struct acl_resource_traits;
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensor> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclTensor> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclIntArray> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclIntArray> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclScalar> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclScalar> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensorList> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclTensorList> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@@ -800,14 +796,8 @@ struct acl_resource_traits<aclTensorList> {
* @param ptr Raw pointer to ACL resource.
* @return any_acl_resource Smart pointer that handles destruction.
*/
template<typename T>
any_acl_resource make_acl_resource(T* ptr) {
return any_acl_resource(
static_cast<void*>(ptr),
[](void* p) {
acl_resource_traits<T>::destroy(p);
}
);
template <typename T> any_acl_resource make_acl_resource(T * ptr) {
return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
}

/**

@@ -817,8 +807,7 @@ any_acl_resource make_acl_resource(T* ptr) {
* @param vec Target vector to hold ACL resources.
* @param args Raw pointers to ACL resources.
*/
template<typename... Args>
void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
(vec.emplace_back(make_acl_resource(args)), ...);
}
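As a hedged illustration of how these helpers compose (the handle names below are mine): a caller can collect heterogeneous ACL handles into one vector and let the unique_ptr deleters run when the vector is cleared or goes out of scope.

// Sketch: acl_src and acl_dst are valid aclTensor* handles, acl_dims is an aclIntArray*.
// Each handle is destroyed through its acl_resource_traits<T>::destroy specialization.
std::vector<any_acl_resource> resources;
register_acl_resources(resources, acl_src, acl_dst, acl_dims);
// ... launch the operator that consumes these handles ...
resources.clear();  // aclDestroyTensor / aclDestroyIntArray run here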
||||
|
|
@ -826,39 +815,36 @@ void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
|
|||
* @brief Task class that wraps the execution of an aclnn function call.
|
||||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
|
||||
uint64_t workspace_size, aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
|
||||
}
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
uint64_t workspace_size_;
|
||||
aclOpExecutor * executor_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource>&& resources){
|
||||
resource_ = std::move(resources);
|
||||
}
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override {
|
||||
resource_.clear();
|
||||
}
|
||||
private:
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
|
|
@ -866,38 +852,40 @@ private:
|
|||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void* dst, const void* src, size_t size,
|
||||
aclrtMemcpyKind kind, aclrtStream stream)
|
||||
: dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
|
||||
}
|
||||
private:
|
||||
void* dst_;
|
||||
const void* src_;
|
||||
size_t size_;
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
aclrtStream stream_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Task class for performing asynchronous memory set operations.
|
||||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
|
||||
: buffer_(buffer), size_(size), value_(value), stream_(stream) {}
|
||||
public:
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
|
||||
}
|
||||
private:
|
||||
void* buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
};
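For reference, a direct-submission sketch that mirrors what the ggml_cann_async_memset helper further below does when async mode is enabled (buffer, size, and ctx are assumed to be valid and in scope):

// Sketch: zero-fill a device buffer asynchronously on the context's stream.
auto task = std::make_unique<async_memset_task>(buffer, size, /*value=*/0, ctx.stream());
ctx.task_queue.submit_task(std::move(task));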
|
||||
|
||||
/**
|
||||
|
|
@ -918,25 +906,24 @@ class async_memset_task : public cann_task {
|
|||
* same stream are executed in queue order.
|
||||
*/
|
||||
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
|
||||
do { \
|
||||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
workspaceAddr = workspace_allocator.get(); \
|
||||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, \
|
||||
executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
|
||||
} \
|
||||
#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
do { \
uint64_t workspaceSize = 0; \
aclOpExecutor * executor; \
void * workspaceAddr = nullptr; \
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
if (workspaceSize > 0) { \
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
workspaceAddr = workspace_allocator.get(); \
} \
if (CTX.async_mode) { \
auto task = \
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
CTX.task_queue.submit_task(std::move(task)); \
} else { \
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
} \
} while (0)
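A minimal call-site sketch (assuming an aclnnCos operator pair exists, as the unary helpers in this header suggest): the macro token-pastes the operator name into both the GetWorkspaceSize query and the launch.

// Sketch: acl_src/acl_dst are aclTensor* created via ggml_cann_create_tensor,
// ctx is a ggml_backend_cann_context exposing pool(), stream(), async_mode and task_queue.
GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
// expands to aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor)
// followed by aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()),
// routed through ctx.task_queue when async_mode is enabled.
ggml_cann_release_resources(ctx, acl_src, acl_dst);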
/**
|
||||
|
|
@ -947,11 +934,10 @@ class async_memset_task : public cann_task {
|
|||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args>
|
||||
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if(ctx.async_mode) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
|
|
@ -966,8 +952,11 @@ void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... arg
|
|||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
|
|
@ -976,8 +965,11 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
|
|||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
|
|
@ -994,8 +986,7 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
|
|||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
|
||||
size_t size, int value) {
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
|
|
@ -1029,7 +1020,7 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
|
|||
* @param dst The destination tensor where the expert-weighted token outputs are stored.
|
||||
* Expected to be of shape [M, K, N, 1].
|
||||
*/
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
|
||||
|
|
@ -1041,20 +1032,14 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @param tensor Pointer to the target ggml_tensor object (const-qualified).
|
||||
*/
|
||||
static bool is_matmul_weight(const ggml_tensor* tensor) {
|
||||
std::string name = ggml_get_name(tensor);
|
||||
static const std::unordered_set<std::string> weight_suffixes{
|
||||
"output.weight",
|
||||
"attn_q.weight",
|
||||
"attn_k.weight",
|
||||
"attn_v.weight",
|
||||
"attn_output.weight",
|
||||
"ffn_gate.weight",
|
||||
"ffn_up.weight",
|
||||
"ffn_down.weight"
|
||||
};
|
||||
static bool is_matmul_weight(const ggml_tensor * tensor) {
|
||||
std::string name = ggml_get_name(tensor);
|
||||
static const std::unordered_set<std::string> weight_suffixes{ "output.weight", "attn_q.weight",
|
||||
"attn_k.weight", "attn_v.weight",
|
||||
"attn_output.weight", "ffn_gate.weight",
|
||||
"ffn_up.weight", "ffn_down.weight" };
|
||||
|
||||
for (const auto& suffix : weight_suffixes) {
|
||||
for (const auto & suffix : weight_suffixes) {
|
||||
if (name.find(suffix) != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1078,14 +1063,13 @@ static bool is_matmul_weight(const ggml_tensor* tensor) {
|
|||
* @param ctx The CANN backend context used to manage execution and resources.
|
||||
* @param dst The destination tensor.
|
||||
*/
|
||||
template <auto binary_op>
void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_tensor* src0 = dst->src[0];
ggml_tensor* src1 = dst->src[1];
template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];

aclTensor* acl_src0;
aclTensor* acl_src1;
aclTensor* acl_dst;
aclTensor * acl_src0;
aclTensor * acl_src1;
aclTensor * acl_dst;

// Need bcast
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);

@@ -1094,7 +1078,6 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
}
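For context, a dispatcher might instantiate this template with one of the aclnn_* wrappers declared above; the actual call sites are outside this hunk, so the function below is only a sketch.

// Sketch: binary_op must be callable as op(ctx, acl_src0, acl_src1, acl_dst),
// which the aclnn_add / aclnn_sub / aclnn_mul / aclnn_div declarations above satisfy.
void example_dispatch_mul(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_cann_binary_op<aclnn_mul>(ctx, dst);  // dst = dst->src[0] * dst->src[1]
}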
|
||||
/**
|
||||
* @brief Applies a unary operation to an input tensor using the CANN backend.
|
||||
*
|
||||
|
|
@ -1107,12 +1090,12 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
*/
|
||||
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src = dst->src[0];
|
||||
template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
|
|
@ -1138,9 +1121,9 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
|||
*
|
||||
* @see GGML_CANN_CALL_OP_UNARY
|
||||
*/
|
||||
void ggml_cann_op_unary(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
|
||||
ggml_backend_cann_context & ctx,
|
||||
ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
|
||||
|
|
@ -1172,9 +1155,9 @@ void ggml_cann_op_unary(
|
|||
*
|
||||
* @see GGML_CANN_CALL_OP_UNARY_GATED
|
||||
*/
|
||||
void ggml_cann_op_unary_gated(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
|
||||
ggml_backend_cann_context & ctx,
|
||||
ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
|
||||
|
|
@ -1197,16 +1180,13 @@ void ggml_cann_op_unary_gated(
|
|||
* @see ggml_cann_op_unary
|
||||
* @see GGML_CANN_CALL_ACLNN_OP
|
||||
*/
|
||||
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context& ctx, \
aclTensor* acl_src, \
aclTensor* acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary(lambda, ctx, dst); \
} \
while (0)
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
do { \
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
}; \
ggml_cann_op_unary(lambda, ctx, dst); \
} while (0)
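A usage sketch: ctx and dst are the names the macro expects to find in scope, and Exp is only an example operator name (an assumption, not taken from this hunk).

// Sketch: inside an op handler such as a hypothetical ggml_cann_exp(ctx, dst),
// where `ctx` and `dst` are in scope as the macro body requires.
GGML_CANN_CALL_OP_UNARY(Exp);  // wraps aclnnExp via ggml_cann_op_unary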
|
||||
/**
|
||||
* @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
|
||||
|
|
@ -1229,15 +1209,12 @@ void ggml_cann_op_unary_gated(
|
|||
* @see ggml_cann_op_unary_gated
|
||||
* @see GGML_CANN_CALL_ACLNN_OP
|
||||
*/
|
||||
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_op_unary_gated(lambda, ctx, dst); \
|
||||
} \
|
||||
while (0)
|
||||
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_op_unary_gated(lambda, ctx, dst); \
|
||||
} while (0)
|
||||
|
||||
#endif // CANN_ACLNN_OPS
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@
|
|||
#include "../include/ggml.h"
|
||||
#include "../ggml-impl.h"
|
||||
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define MATRIX_ROW_PADDING 512
|
||||
#define GGML_CANN_MAX_STREAMS 8
|
||||
|
||||
/**
|
||||
|
|
@ -56,8 +56,7 @@
|
|||
* @param line The line number at which the error occurred.
|
||||
* @param msg The error message.
|
||||
*/
|
||||
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
|
||||
const char* file, int line, const char* msg);
|
||||
[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
|
||||
|
||||
/**
|
||||
* @brief Checks the result of a CANN function call and invokes the error
|
||||
|
|
@ -89,25 +88,24 @@ struct ggml_cann_device_info {
|
|||
* @brief Information about a single CANN device.
|
||||
*/
|
||||
struct cann_device_info {
|
||||
int cc; /**< Compute capability. */
|
||||
int cc; /**< Compute capability. */
|
||||
size_t smpb; /**< Maximum shared memory per block. */
|
||||
bool vmm; /**< Virtual memory support. */
|
||||
bool vmm; /**< Virtual memory support. */
|
||||
size_t vmm_granularity; /**< Granularity of virtual memory. */
|
||||
size_t total_vram; /**< Total video RAM available on the device. */
|
||||
};
|
||||
|
||||
cann_device_info devices[GGML_CANN_MAX_DEVICES] =
|
||||
{}; /**< Array of CANN device information. */
|
||||
cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */
|
||||
};
|
||||
|
||||
const ggml_cann_device_info& ggml_cann_info();
|
||||
const ggml_cann_device_info & ggml_cann_info();
|
||||
|
||||
void ggml_cann_set_device(int32_t device);
|
||||
void ggml_cann_set_device(int32_t device);
|
||||
int32_t ggml_cann_get_device();
|
||||
|
||||
std::optional<std::string> get_env(const std::string& name);
|
||||
bool parse_bool(const std::string& value);
|
||||
int parse_integer(const std::string& value);
|
||||
std::optional<std::string> get_env(const std::string & name);
|
||||
bool parse_bool(const std::string & value);
|
||||
int parse_integer(const std::string & value);
|
||||
|
||||
/**
|
||||
* @brief Abstract base class for memory pools used by CANN.
|
||||
|
|
@ -126,7 +124,7 @@ struct ggml_cann_pool {
|
|||
* will be stored.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
virtual void* alloc(size_t size, size_t* actual_size) = 0;
|
||||
virtual void * alloc(size_t size, size_t * actual_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Frees a previously allocated memory block.
|
||||
|
|
@ -136,16 +134,16 @@ struct ggml_cann_pool {
|
|||
* @note Note that all CANN opertors are running async. Make sure memory is
|
||||
* still avaiable before this operator finished.
|
||||
*/
|
||||
virtual void free(void* ptr, size_t size) = 0;
|
||||
virtual void free(void * ptr, size_t size) = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RAII wrapper for managing memory allocations from a CANN memory pool.
|
||||
*/
|
||||
struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
|
||||
void* ptr = nullptr; /**< Pointer to the allocated memory block. */
|
||||
size_t actual_size = 0; /**< Actual size of the allocated memory block. */
|
||||
ggml_cann_pool * pool = nullptr; /**< Pointer to the memory pool. */
|
||||
void * ptr = nullptr; /**< Pointer to the allocated memory block. */
|
||||
size_t actual_size = 0; /**< Actual size of the allocated memory block. */
|
||||
|
||||
/**
|
||||
* @brief Default constructor.
|
||||
|
|
@ -156,16 +154,14 @@ struct ggml_cann_pool_alloc {
|
|||
* @brief Constructor that initializes the memory pool.
|
||||
* @param pool Reference to the memory pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
|
||||
explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}
|
||||
|
||||
/**
|
||||
* @brief Constructor that initializes the memory pool and allocates memory.
|
||||
* @param pool Reference to the memory pool.
|
||||
* @param size Size of the memory block to allocate.
|
||||
*/
|
||||
ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
|
||||
alloc(size);
|
||||
}
|
||||
ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }
|
||||
|
||||
/**
|
||||
* @brief Destructor that frees the allocated memory block.
|
||||
|
|
@ -181,7 +177,7 @@ struct ggml_cann_pool_alloc {
|
|||
* @param size Size of the memory block to allocate.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* alloc(size_t size) {
|
||||
void * alloc(size_t size) {
|
||||
GGML_ASSERT(pool != nullptr);
|
||||
GGML_ASSERT(ptr == nullptr);
|
||||
ptr = pool->alloc(size, &this->actual_size);
|
||||
|
|
@ -194,7 +190,7 @@ struct ggml_cann_pool_alloc {
|
|||
* @param size Size of the memory block to allocate.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* alloc(ggml_cann_pool& pool, size_t size) {
|
||||
void * alloc(ggml_cann_pool & pool, size_t size) {
|
||||
this->pool = &pool;
|
||||
return alloc(size);
|
||||
}
|
||||
|
|
@ -203,25 +199,25 @@ struct ggml_cann_pool_alloc {
|
|||
* @brief Gets the pointer to the allocated memory block.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* get() { return ptr; }
|
||||
void * get() { return ptr; }
|
||||
|
||||
// Deleted copy constructor
|
||||
ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
|
||||
ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;
|
||||
|
||||
// Deleted move constructor
|
||||
ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
|
||||
ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;
|
||||
|
||||
// Deleted copy assignment operator
|
||||
ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
|
||||
ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete;
|
||||
|
||||
// Deleted move assignment operator
|
||||
ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
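A short RAII sketch (the buffer size and variable names are mine): the allocation is returned to the pool when the wrapper goes out of scope.

// Sketch: ctx.pool() returns the ggml_cann_pool for the current device,
// as used by GGML_CANN_CALL_ACLNN_OP above.
{
    ggml_cann_pool_alloc scratch(ctx.pool(), 4 * 1024 * 1024);  // 4 MiB scratch buffer
    void * buf = scratch.get();
    // ... use buf for an intermediate result ...
}   // destructor returns the block to the pool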
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
|
|
@ -229,7 +225,7 @@ using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStrea
|
|||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
|
|
@ -237,16 +233,20 @@ public:
|
|||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device)
|
||||
: buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
|
||||
running_(false), device_(device) {
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
|
@ -257,7 +257,7 @@ public:
|
|||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task>&& item) {
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
|
|
@ -276,17 +276,16 @@ public:
|
|||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task>&& task) {
|
||||
while(!enqueue(std::move(task))) {
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!running_) {
|
||||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -309,7 +308,7 @@ public:
        }
    }

  private:
    /**
     * @brief Worker thread function that continuously dequeues and executes tasks.
     */

@ -317,7 +316,7 @@ private:
        ggml_cann_set_device(device_);

        while (running_) {
            if (head_ == tail_) {
                std::this_thread::yield();
                continue;
            }

@ -330,22 +329,29 @@ private:
        }

    std::vector<std::unique_ptr<cann_task>> buffer_;
    const size_t                            capacity_;
    size_t                                  mask_;
    size_t                                  head_;
    size_t                                  tail_;
    bool                                    running_;
    std::thread                             thread_;
    int32_t                                 device_;
};

#ifdef USE_ACL_GRAPH
struct ggml_graph_node_properties {
    // dst tensor
    void *  node_address;
    int64_t ne[GGML_MAX_DIMS];
    size_t  nb[GGML_MAX_DIMS];

    // src tensor
    void *  src_address[GGML_MAX_SRC];
    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];

    // op
    ggml_op node_op;
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

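Returning to the task queue above, here is a minimal usage sketch, assuming only the interface shown in this header (cann_task::run_task() and cann_task_queue::submit_task()); it is an illustration, and print_task is a hypothetical subclass, not code from this change:

#include <cstdio>
#include <memory>

// hypothetical task that wraps a unit of work
class print_task : public cann_task {
  public:
    explicit print_task(int id) : id_(id) {}
    void run_task() override { printf("running task %d\n", id_); }
  private:
    int id_;
};

void example(cann_task_queue & queue) {
    for (int i = 0; i < 4; ++i) {
        // the worker thread is started lazily on the first submission
        queue.submit_task(std::make_unique<print_task>(i));
    }
}
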
@ -369,13 +375,11 @@ struct ggml_cann_graph {
 * move existing graphs to the front (most recently used), and clear the cache.
 */
struct ggml_cann_graph_lru_cache {
    size_t capacity; /**< Maximum number of graphs in the cache. */

    std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */

    ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }

    /**
     * @brief Push a new graph to the front of the cache.

@ -383,11 +387,11 @@ struct ggml_cann_graph_lru_cache {
     * @param new_node Pointer to the new ggml_cann_graph to cache.
     *                 Ownership is transferred to the cache (cache will delete it).
     */
    void push(ggml_cann_graph * new_node) {
        if (cache_list.size() >= capacity) {
            ggml_cann_graph * old = cache_list.back();
            cache_list.pop_back();
            delete old;  // free the old graph
        }
        cache_list.push_front(new_node);
    }

@ -396,7 +400,7 @@ struct ggml_cann_graph_lru_cache {
     * @brief Move an existing graph to the front of the cache.
     * @param node Pointer to the ggml_cann_graph to move.
     */
    void move_to_front(ggml_cann_graph * node) {
        cache_list.remove(node);
        cache_list.push_front(node);
    }

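The two methods above implement the usual LRU discipline: push() evicts from the back once capacity is reached, and move_to_front() promotes a reused entry. A tiny standalone illustration of the same discipline on plain ints, for intuition only (it does not use ggml_cann_graph):

#include <cstddef>
#include <cstdio>
#include <list>

int main() {
    const std::size_t capacity = 3;
    std::list<int> cache;

    auto push = [&](int v) {
        if (cache.size() >= capacity) {
            cache.pop_back();       // evict least recently used
        }
        cache.push_front(v);        // newest at the front
    };

    const int vals[] = { 1, 2, 3, 4 };
    for (int v : vals) push(v);               // 1 is evicted
    cache.remove(3); cache.push_front(3);     // "move_to_front(3)"

    for (int v : cache) printf("%d ", v);     // prints: 3 4 2
    return 0;
}
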
@ -414,92 +418,89 @@ struct ggml_cann_graph_lru_cache {
    /**
     * @brief Destructor that clears the cache and frees all cached graphs.
     */
    ~ggml_cann_graph_lru_cache() { clear(); }
};
#endif  // USE_ACL_GRAPH

struct ggml_cann_rope_cache {
    ~ggml_cann_rope_cache() {
        if (theta_scale_cache != nullptr) {
            ACL_CHECK(aclrtFree(theta_scale_cache));
        }
        if (sin_cache != nullptr) {
            ACL_CHECK(aclrtFree(sin_cache));
        }
        if (cos_cache != nullptr) {
            ACL_CHECK(aclrtFree(cos_cache));
        }
    }

    void *  theta_scale_cache  = nullptr;
    int64_t theta_scale_length = 0;
    // sin/cos cache, used only to accelerate first layer on each device
    void *  sin_cache       = nullptr;
    void *  cos_cache       = nullptr;
    int64_t position_length = 0;
    // Properties to check before reusing the sincos cache
    bool  cached      = false;
    float ext_factor  = 0.0f;
    float theta_scale = 0.0f;
    float freq_scale  = 0.0f;
    float attn_factor = 0.0f;
    bool  is_neox     = false;
};

struct ggml_cann_tensor_cache {
    ~ggml_cann_tensor_cache() {
        if (cache != nullptr) {
            ACL_CHECK(aclrtFree(cache));
        }
    }

    void *  cache = nullptr;
    int64_t size  = 0;
};

/**
 * @brief Context for managing CANN backend operations.
 */
struct ggml_backend_cann_context {
    int32_t     device;               /**< Device ID. */
    std::string name;                 /**< Name of the device. */
    std::string description;          /**< Description of the device. */
    aclrtEvent  copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH
    /// Cached CANN ACL graph used for executing the current ggml computation graph.
    ggml_cann_graph_lru_cache graph_lru_cache;
    bool                      acl_graph_mode = true;
#endif
    cann_task_queue task_queue;
    bool            async_mode;
    // Rope Cache
    ggml_cann_rope_cache rope_cache;
    // Constant Pool
    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;

    aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */

    /**
     * @brief Constructor for initializing the context with a given device.
     * @param device Device ID.
     */
    explicit ggml_backend_cann_context(int device) :
        device(device),
        name("CANN" + std::to_string(device)),
        task_queue(1024, device) {
        ggml_cann_set_device(device);
        description = aclrtGetSocName();

        async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
        GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
        acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
        GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
                      acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
    }

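The constructor above derives both execution modes purely from environment variables; async submission appears to default to off when GGML_CANN_ASYNC_MODE is unset, while GGML_CANN_ACL_GRAPH defaults to "on". A standalone sketch of the equivalent decision logic using std::getenv (the real code goes through the get_env/parse_bool helpers, which are not shown in this hunk, and their accepted spellings may differ):

#include <cstdio>
#include <cstdlib>
#include <cstring>

// simplified stand-in for parse_bool(get_env(name).value_or(...))
static bool env_flag(const char * name, bool def) {
    const char * v = std::getenv(name);
    if (v == nullptr || v[0] == '\0') {
        return def;
    }
    return std::strcmp(v, "1") == 0 || std::strcmp(v, "on") == 0 ||
           std::strcmp(v, "yes") == 0 || std::strcmp(v, "true") == 0;
}

int main() {
    const bool async_mode     = env_flag("GGML_CANN_ASYNC_MODE", false);  // OFF unless explicitly enabled
    const bool acl_graph_mode = env_flag("GGML_CANN_ACL_GRAPH", true);    // ON unless explicitly disabled
    printf("async=%d graph=%d\n", async_mode, acl_graph_mode);
    return 0;
}
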
@ -542,8 +543,7 @@ struct ggml_backend_cann_context {
    aclrtStream stream() { return stream(0); }

    // TODO: each stream should have a memory pool.
    std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */

    /**
     * @brief Create a new memory pool for a given device.

@ -556,7 +556,7 @@ struct ggml_backend_cann_context {
     * @brief Get or create the memory pool for the context.
     * @return Reference to the memory pool.
     */
    ggml_cann_pool & pool() {
        if (mem_pool == nullptr) {
            mem_pool = new_pool_for_device(device);
        }

File diff suppressed because it is too large

@ -466,29 +466,45 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
        list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
    elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
        message(STATUS "s390x detected")
        list(APPEND GGML_CPU_SOURCES
            ggml-cpu/arch/s390/quants.c)

        # for native compilation
        if (GGML_NATIVE)
            # check machine level to determine target
            file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
            string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})

            # TODO: Separation to determine activation of VX/VXE/VXE2
            if (${S390X_M} MATCHES "8561|8562")
                message(STATUS "z15 target")
                list(APPEND ARCH_FLAGS -march=z15)
            elseif (${S390X_M} MATCHES "3931")
                message(STATUS "z16 target")
                list(APPEND ARCH_FLAGS -march=z16)
            elseif (${S390X_M} MATCHES "9175|9176")
                # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
                # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
                message(STATUS "z17 target")
                list(APPEND ARCH_FLAGS -march=arch15)
            else()
                message(STATUS "Unknown target")
                message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
                list(APPEND ARCH_FLAGS -march=native -mtune=native)
            endif()
        # for cross-compilation
        elseif(GGML_CPU_ALL_VARIANTS)
            # range through IBM z15 to z17
            # NOTE: update when a new hardware level is released
            foreach (ZHW RANGE 15 17)
                if(DEFINED GGML_INTERNAL_Z${ZHW})
                    message(STATUS "z${ZHW} cross-compile target")
                    list(APPEND ARCH_FLAGS -march=z${ZHW})
                endif()
            endforeach()
        endif()

        if (GGML_VXE OR GGML_INTERNAL_VXE)
            message(STATUS "VX/VXE/VXE2 enabled")
            list(APPEND ARCH_FLAGS -mvx -mzvector)
            list(APPEND ARCH_DEFINITIONS GGML_VXE)

@ -68,7 +68,7 @@ struct ggml_compute_params
#endif // __VXE2__
#endif // __s390x__ && __VEC__

#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
#include <sys/prctl.h>
#endif

@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
#endif

static void ggml_init_arm_arch_features(void) {
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
#if defined(__linux__)
    ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
#else
    // TODO: add support of SVE for non-linux systems
#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
#endif
#endif
}

@ -2179,6 +2184,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
            case GGML_UNARY_OP_HARDSWISH:
            case GGML_UNARY_OP_HARDSIGMOID:
            case GGML_UNARY_OP_EXP:
            case GGML_UNARY_OP_FLOOR:
            case GGML_UNARY_OP_CEIL:
            case GGML_UNARY_OP_ROUND:
            case GGML_UNARY_OP_TRUNC:
                {
                    n_tasks = 1;
                } break;

@ -3558,13 +3567,17 @@ void ggml_cpu_init(void) {
#ifdef GGML_USE_OPENMP
        //if (!getenv("OMP_WAIT_POLICY")) {
        //    // set the wait policy to active, so that OpenMP threads don't sleep
        //    setenv("OMP_WAIT_POLICY", "active", 0)
        //}

        if (!getenv("KMP_BLOCKTIME")) {
            // set the time to wait before sleeping a thread
            // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
#ifdef _WIN32
            _putenv_s("KMP_BLOCKTIME", "200"); // 200ms
#else
            setenv("KMP_BLOCKTIME", "200", 0); // 200ms
#endif
        }
#endif
    }

@ -29,6 +29,108 @@

#define NELEMS(x) sizeof(x) / sizeof(*x)

template<size_t(*Fn)(size_t,size_t,size_t)>
static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) {
    return Fn(a, b, c);
}

template<size_t(*Fn)(size_t,size_t)>
static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) {
    return Fn(a, b);
}

template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl,
                                   const void* lhs, const void* rhs, void* dst,
                                   size_t dst_stride_row, size_t dst_stride_col,
                                   float clamp_min, float clamp_max) {
    Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}

template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)>
static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
                                   const void* lhs, const void* rhs, void* dst,
                                   size_t dst_stride_row, size_t dst_stride_col,
                                   float clamp_min, float clamp_max) {
    Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    return Fn(m, k, bl, mr, kr, sr);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
    return Fn(m, k, mr, kr, sr);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
    return Fn(m_idx, k, bl, mr, kr, sr);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) {
    return Fn(m_idx, k, mr, kr, sr);
}

template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
                                       size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
    Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
}

template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
                                      size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
    Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}

template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)>
static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
                                     size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) {
    Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
    return Fn(n, k, nr, kr, bl);
}

template<size_t(*Fn)(size_t,size_t)>
static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
    return Fn(n, k);
}

template<size_t(*Fn)(size_t,size_t,size_t,size_t)>
static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) {
    return Fn(k, nr, kr, bl);
}

template<size_t(*Fn)(size_t)>
static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) {
    return Fn(k);
}

template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)>
static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
                                 size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/,
                                 void* rhs_packed, size_t extra_bytes, const void* params) {
    Fn(num_groups, n, k, nr, kr, sr, bl,
       static_cast<const uint8_t*>(rhs),
       static_cast<const float*>(bias),
       rhs_packed, extra_bytes,
       static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}

template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                 size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
                                 void* rhs_packed, size_t extra_bytes, const void* params) {
    Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params);
}

static const size_t INT4_PER_BYTE = 2;
static const size_t INT4_BITS = 4;
static const int Q4_0_ZERO_POINT = 8;

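These thin templates give every KleidiAI entry point, whatever its arity, a single uniform function-pointer shape that can be stored in the kernel tables below (the *_ex fields), without capturing lambdas or std::function. A self-contained sketch of the same pattern with made-up functions, shown only to illustrate the idea:

#include <cstddef>
#include <cstdio>

// two "library" functions with different arities
size_t offset2(size_t a, size_t b)           { return a + b; }
size_t offset3(size_t a, size_t b, size_t c) { return a + b + c; }

// adapters: both expose the 3-argument shape, dropping what the target does not use
template<size_t(*Fn)(size_t, size_t)>
size_t offs_fn2(size_t a, size_t b, size_t /*c*/) { return Fn(a, b); }

template<size_t(*Fn)(size_t, size_t, size_t)>
size_t offs_fn3(size_t a, size_t b, size_t c)     { return Fn(a, b, c); }

// a table can now hold one pointer type for both
using offs_ptr = size_t (*)(size_t, size_t, size_t);
static const offs_ptr table[] = { &offs_fn2<offset2>, &offs_fn3<offset3> };

int main() {
    printf("%zu %zu\n", table[0](1, 2, 99), table[1](1, 2, 3));  // prints: 3 6
    return 0;
}
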
@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
|
||||
},
|
||||
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
|
||||
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
},
|
||||
/* SME GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
|
||||
/* .get_lhs_offset_ex = */ nullptr,
|
||||
/* .get_rhs_packed_offset_ex = */ nullptr,
|
||||
/* .run_kernel_ex = */ nullptr,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
/* .pack_func_ex = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .packed_stride = */ NULL,
|
||||
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
|
||||
/* .to_float = */ NULL,
|
||||
/* .packed_stride = */ nullptr,
|
||||
/* .to_float = */ nullptr,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SME,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
},
|
||||
/* i8mm GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* DOTPROD GEMV */
|
||||
/* .kern_info = */ {
|
||||
|
|
@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
|
|
@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
    ggml_kleidiai_kernels * kernel = nullptr;

    if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&

@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                break;
            }
        }
#endif
    }

    return kernel;

@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
    ggml_kleidiai_kernels * kernels = nullptr;

#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) {
        if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
            kernels = &gemm_gemv_kernels[i];
            break;
        }
    }
#endif

    return kernels;
}

@ -4,8 +4,6 @@

#pragma once

#include "ggml.h"

enum cpu_feature {

@ -15,6 +13,7 @@ enum cpu_feature {
    CPU_FEATURE_SVE = 4,
    CPU_FEATURE_SME = 8
};

inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
    lhs = static_cast<cpu_feature>(lhs | rhs);
    return lhs;

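The operator|= overload above lets the detection code accumulate CPU features into a single bitmask, and the selection loops earlier in this diff test a kernel's requirements with (features & required) == required. A small standalone sketch of that flow; the enum values are mirrored locally so the example compiles on its own, and the actual runtime feature detection is not shown:

#include <cstdio>

// mirrors the enum and operator defined above, for a self-contained example
enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, CPU_FEATURE_I8MM = 2, CPU_FEATURE_SVE = 4, CPU_FEATURE_SME = 8 };
inline cpu_feature & operator|=(cpu_feature & lhs, cpu_feature rhs) {
    lhs = static_cast<cpu_feature>(lhs | rhs);
    return lhs;
}

int main() {
    cpu_feature features = CPU_FEATURE_NONE;
    features |= CPU_FEATURE_DOTPROD;   // e.g. reported by the platform's hwcaps
    features |= CPU_FEATURE_I8MM;

    const cpu_feature required = static_cast<cpu_feature>(CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM);
    const bool usable = (features & required) == required;
    printf("i8mm kernel usable: %s\n", usable ? "yes" : "no");  // yes
    return 0;
}
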
@ -30,63 +29,52 @@ struct kernel_info {
    size_t (*get_nr)(void);
    size_t (*get_kr)(void);
    size_t (*get_sr)(void);

    size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
    size_t (*get_dst_size)(size_t m, size_t n);

    size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);

    size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);

    void (*run_kernel_ex)(
        size_t m, size_t n, size_t k, size_t bl,
        const void* lhs_packed, const void* rhs_packed,
        void* dst, size_t dst_stride_row, size_t dst_stride_col,
        float clamp_min, float clamp_max);
};

struct lhs_packing_info {
    size_t (*get_offset)(size_t m_idx, size_t lhs_stride);

    size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);

    size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);

    void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
                         size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
};

struct rhs_packing_info {
    void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
                     size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
                     size_t num_bytes_multiplier);

    size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);

    size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);

    void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
                         size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
};

struct ggml_kleidiai_kernels {
    kernel_info      gemm;
    lhs_packing_info gemm_lhs_info;

    kernel_info      gemv;
    lhs_packing_info gemv_lhs_info;

    rhs_packing_info rhs_info;

@ -8,6 +8,7 @@
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>

@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
|
|||
return tensor->ne[dim];
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
|
||||
constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
|
||||
using V = std::remove_reference_t<Variant>;
|
||||
return (std::is_invocable_r_v<
|
||||
Ret,
|
||||
std::variant_alternative_t<Is, V>,
|
||||
Args...> || ...);
|
||||
}
|
||||
|
||||
template <typename Variant, typename Ret, typename... Args>
|
||||
constexpr bool variant_any_invocable_v =
|
||||
variant_any_invocable_impl<Variant, Ret, Args...>(
|
||||
std::make_index_sequence<
|
||||
std::variant_size_v<std::remove_reference_t<Variant>>>{});
|
||||
|
||||
template<typename Ret, typename Variant, typename... Args>
|
||||
static inline Ret variant_call(Variant && var, Args&&... args) {
|
||||
static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
|
||||
"No alternative in Variant is invocable with the provided arguments and return type.");
|
||||
|
||||
return std::visit(
|
||||
[&](auto && f) -> Ret {
|
||||
using F = std::decay_t<decltype(f)>;
|
||||
if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
|
||||
return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
|
||||
} else {
|
||||
GGML_ABORT("Invalid function type in variant_call");
|
||||
GGML_UNREACHABLE();
|
||||
}
|
||||
},
|
||||
std::forward<Variant>(var)
|
||||
);
|
||||
}
|
||||
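For context, the helper removed here forwarded a call to whichever alternative of a std::variant of callables was invocable with the given arguments. A small standalone illustration of that std::visit pattern, with toy types rather than the ggml code:

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <type_traits>
    #include <variant>

    using packed_size_v = std::variant<
        std::function<size_t(size_t n, size_t k, size_t bl)>,  // blocked alternative
        std::function<size_t(size_t n, size_t k)>>;            // plain alternative

    // Call whichever alternative matches the supplied arguments, as variant_call did.
    static size_t call_packed_size(const packed_size_v & v, size_t n, size_t k, size_t bl) {
        return std::visit([&](const auto & f) -> size_t {
            using F = std::decay_t<decltype(f)>;
            if constexpr (std::is_invocable_r_v<size_t, F, size_t, size_t, size_t>) {
                return f(n, k, bl);
            } else {
                return f(n, k);
            }
        }, v);
    }

    int main() {
        packed_size_v a = std::function<size_t(size_t, size_t, size_t)>([](size_t n, size_t k, size_t bl) { return n * k / bl; });
        packed_size_v b = std::function<size_t(size_t, size_t)>([](size_t n, size_t k) { return n * k; });
        std::printf("%zu %zu\n", call_packed_size(a, 8, 64, 32), call_packed_size(b, 8, 64, 32));
        return 0;
    }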
|
||||
namespace ggml::cpu::kleidiai {
|
||||
|
||||
static size_t round_down(size_t x, size_t y) {
|
||||
|
|
@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
return false;
|
||||
}
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
bool is_gemv = op->src[1]->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
|
@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
size_t sr = kernel->get_sr();
|
||||
|
||||
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
|
||||
if (!lhs_info->packed_size_ex) return false;
|
||||
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
|
||||
} else if (kernels->rhs_type == GGML_TYPE_F16) {
|
||||
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
|
||||
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
|
||||
const int64_t rhs_batch_size0 = op->src[0]->ne[2];
|
||||
const int64_t r = lhs_batch_size0 / rhs_batch_size0;
|
||||
size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
|
||||
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
|
||||
size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) +
|
||||
kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) +
|
||||
k * n * sizeof(float) + n * sizeof(float);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
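As a sanity check on the F16 work-size arithmetic above: the buffer holds the packed LHS, the packed RHS, a k x n f32 scratch copy of the RHS, and an n-element bias row. A toy sketch of the same sum; the two packed sizes are placeholders for what the *_ex callbacks would return:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Illustrative dimensions; lhs_packed/rhs_packed stand in for the *_ex callbacks.
        const size_t m = 32, n = 128, k = 256, r = 1;           // r = lhs_batch / rhs_batch
        const size_t lhs_packed = (m * r) * k * sizeof(float);  // placeholder for lhs packed_size_ex(m*r, k, ...)
        const size_t rhs_packed = n * k * sizeof(float);        // placeholder for rhs packed_size_ex(n, k, ...)
        const size_t scratch    = k * n * sizeof(float);        // f32 copy of the RHS used while repacking
        const size_t bias       = n * sizeof(float);
        std::printf("wsize = %zu bytes\n", lhs_packed + rhs_packed + scratch + bias);
        return 0;
    }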
@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
GGML_ASSERT(kernel);
|
||||
if (!kernels->rhs_info.pack_func_ex ||
|
||||
!kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int nth = params->nth;
|
||||
const int ith = params->ith;
|
||||
|
|
@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
const int64_t kr = (int64_t) kernel->get_kr();
|
||||
const int64_t sr = (int64_t) kernel->get_sr();
|
||||
|
||||
const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
|
||||
const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
|
||||
const size_t bias_size = (size_t)n * sizeof(float);
|
||||
const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr);
|
||||
const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0);
|
||||
const size_t kxn_size = k * n * sizeof(float);
|
||||
const size_t bias_size = n * sizeof(float);
|
||||
|
||||
const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
|
||||
GGML_ASSERT(wsize_required <= params->wsize);
|
||||
|
|
@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
|
||||
|
||||
// Base packed offset (aligned) and per-row stride in bytes
|
||||
const size_t base_packed_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t next_block_off = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t base_packed_off = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
|
||||
const size_t next_block_off = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr);
|
||||
const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
|
||||
|
||||
int64_t remaining = m_count;
|
||||
|
|
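The per-row stride above is recovered by differencing two consecutive packed offsets and dividing by mr. The same arithmetic with a toy offset rule; the real offsets come from get_packed_offset_ex:

    #include <cstddef>
    #include <cstdio>

    // Toy packed-offset rule: rows are stored in blocks of mr, one byte per element (illustrative).
    static size_t toy_packed_offset(size_t m_idx, size_t k, size_t mr) {
        return (m_idx / mr) * mr * k;   // offset of the block containing row m_idx
    }

    int main() {
        const size_t m_start = 8, k = 64, mr = 4;
        const size_t base = toy_packed_offset(m_start,      k, mr);
        const size_t next = toy_packed_offset(m_start + mr, k, mr);
        const size_t row_stride_bytes = (next - base) / mr;   // same derivation as the code above
        std::printf("row stride: %zu bytes\n", row_stride_bytes);
        return 0;
    }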
@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
|
||||
void * dst_ptr = lhs_packed + dst_off;
|
||||
|
||||
variant_call<void>(lhs_info->pack_func,
|
||||
(size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
|
||||
/*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
|
||||
lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
|
||||
|
||||
cur += take;
|
||||
remaining -= take;
|
||||
|
|
@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
reinterpret_cast<const uint16_t *>(rhs_batch_base),
|
||||
rhs_stride);
|
||||
|
||||
variant_call<void>(kernels->rhs_info.pack_func,
|
||||
/*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
|
||||
/*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
|
||||
kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float),
|
||||
rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
|
@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
|
||||
|
||||
// LHS packed base at row 0 (consistent with packing above)
|
||||
const size_t lhs_packed_offset0 = variant_call<size_t>(
|
||||
lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
|
||||
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||
const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
|
||||
const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
|
||||
|
||||
const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
|
||||
const void * rhs_ptr = rhs_packed + rhs_packed_offset;
|
||||
float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
|
||||
|
||||
variant_call<void>(kernel->run_kernel,
|
||||
(size_t)m, (size_t)n_to_process, (size_t)k,
|
||||
lhs_ptr, rhs_ptr,
|
||||
dst_ptr, dst_stride, sizeof(float),
|
||||
-FLT_MAX, FLT_MAX);
|
||||
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
|
||||
GGML_ASSERT(kernels);
|
||||
if (!kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_gemv = src1->ne[1] == 1;
|
||||
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
|
||||
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
|
||||
|
||||
GGML_ASSERT(kernel);
|
||||
if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
|
||||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth_raw = params->nth;
|
||||
|
|
@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
// Transform LHS
|
||||
const size_t src_stride = src1->nb[1];
|
||||
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr);
|
||||
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
|
||||
|
||||
variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
// Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer
|
||||
lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
// Perform the operation
|
||||
const size_t dst_stride = dst->nb[1];
|
||||
const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
|
||||
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr);
|
||||
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0);
|
||||
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
|
||||
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
|
||||
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
|
||||
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
|
||||
|
||||
if (n_to_process > 0) {
|
||||
variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
|
||||
sizeof(float), -FLT_MAX, FLT_MAX);
|
||||
}
|
||||
|
||||
|
|
@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
|
||||
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
if (!ctx.kernels) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
|
@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||
|
||||
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
||||
kernel_info * kernel = &ctx.kernels->gemm;
|
||||
if (!rhs_info->to_float || !kernel->get_nr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
|
|
@ -480,7 +458,7 @@ public:
|
|||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
GGML_UNUSED(data_size);
|
||||
|
|
@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_
|
|||
const size_t nr = ctx.kernels->gemm.get_nr();
|
||||
const size_t kr = ctx.kernels->gemm.get_kr();
|
||||
|
||||
return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32(
|
|||
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
|
||||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)x[i00];
|
||||
}
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
|
||||
ggml_float sum2 = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
float v = x[i00] - mean;
|
||||
y[i00] = v;
|
||||
sum2 += (ggml_float)(v*v);
|
||||
}
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
float variance = sum2/ne00;
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
}
|
||||
}
|
||||
|
|
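The rewritten norm above gets the row mean from ggml_vec_sum_f32 and the variance from ggml_vec_cvar_f32, which also centers y as a side effect. A scalar reference of the same math, with the corresponding ggml helpers noted in comments:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   ne00 = 4;
        const float eps  = 1e-5f;
        const float x[ne00] = { 1.0f, 2.0f, 3.0f, 4.0f };
        float y[ne00];

        float sum = 0.0f;                        // ggml_vec_sum_f32(ne00, &sum, x)
        for (int i = 0; i < ne00; ++i) sum += x[i];
        const float mean = sum / ne00;

        double sum2 = 0.0;                       // ggml_vec_cvar_f32(ne00, y, x, mean)
        for (int i = 0; i < ne00; ++i) {
            const float v = x[i] - mean;         // y is centered as a side effect
            y[i] = v;
            sum2 += (double) v * v;
        }
        const float variance = (float) (sum2 / ne00);

        const float scale = 1.0f / std::sqrt(variance + eps);
        for (int i = 0; i < ne00; ++i) y[i] *= scale;   // ggml_vec_scale_f32(ne00, y, scale)

        std::printf("mean=%f var=%f\n", mean, variance);
        return 0;
    }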
@ -8997,6 +8993,22 @@ void ggml_compute_forward_unary(
|
|||
{
|
||||
ggml_compute_forward_exp(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
{
|
||||
ggml_compute_forward_floor(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
{
|
||||
ggml_compute_forward_ceil(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
{
|
||||
ggml_compute_forward_round(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
{
|
||||
ggml_compute_forward_trunc(params, dst);
|
||||
} break;
|
||||
case GGML_UNARY_OP_XIELU:
|
||||
{
|
||||
ggml_compute_forward_xielu(params, dst);
@ -485,8 +485,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|||
int32_t start = ith * task_per_thread;
|
||||
int32_t end = std::min((ith + 1) * task_per_thread, task_count);
|
||||
for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
|
||||
int32_t gemm_idx = compute_idx / block_size_m;
|
||||
int32_t m_idx = compute_idx % block_size_m * block_size_m;
|
||||
int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
|
||||
int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
|
||||
int32_t m_idx = block_idx_in_gemm * block_size_m;
|
||||
const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
|
||||
int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
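The index fix above matters whenever block_size_m differs from the number of m-blocks per GEMM. A worked example, assuming per_gemm_block_count_m is the ceiling of gemm_m / block_size_m:

    #include <cstdio>

    int main() {
        const int gemm_m = 10, block_size_m = 4;
        const int per_gemm_block_count_m = (gemm_m + block_size_m - 1) / block_size_m;  // 3

        const int compute_idx = 5;
        // Fixed indexing: split the flat index by blocks-per-GEMM, then scale to rows.
        const int gemm_idx          = compute_idx / per_gemm_block_count_m;   // 1
        const int block_idx_in_gemm = compute_idx % per_gemm_block_count_m;   // 2
        const int m_idx             = block_idx_in_gemm * block_size_m;       // 8
        // The old code divided by block_size_m instead, which lands on the wrong block
        // whenever block_size_m != per_gemm_block_count_m (here it would give m_idx = 4).
        std::printf("gemm=%d m=%d\n", gemm_idx, m_idx);
        return 0;
    }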
@ -73,6 +73,22 @@ static inline float op_log(float x) {
    return logf(x);
}

static inline float op_floor(float x) {
    return floorf(x);
}

static inline float op_ceil(float x) {
    return ceilf(x);
}

static inline float op_round(float x) {
    return roundf(x);
}

static inline float op_trunc(float x) {
    return truncf(x);
}

template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
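Each of the scalar ops above is lifted elementwise by vec_unary_op and dispatched from ggml_compute_forward_unary further down. A standalone sketch of that pattern for the plain f32 case; the real template additionally converts through the per-type conversion tables:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Elementwise application of a scalar op, mirroring vec_unary_op for f32 -> f32.
    template <float (*op)(float)>
    static void vec_unary_op_f32(int64_t n, float * y, const float * x) {
        for (int64_t i = 0; i < n; ++i) {
            y[i] = op(x[i]);
        }
    }

    static float op_floor_ref(float x) { return std::floor(x); }   // mirrors op_floor / floorf
    static float op_round_ref(float x) { return std::round(x); }   // mirrors op_round / roundf

    int main() {
        const float x[4] = { -1.7f, -0.2f, 0.5f, 2.3f };
        float y[4];
        vec_unary_op_f32<op_floor_ref>(4, y, x);
        std::printf("floor: %g %g %g %g\n", y[0], y[1], y[2], y[3]);
        vec_unary_op_f32<op_round_ref>(4, y, x);
        std::printf("round: %g %g %g %g\n", y[0], y[1], y[2], y[3]);
        return 0;
    }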
@ -274,6 +290,22 @@ void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor *
|
|||
unary_op<op_log>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_floor(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_floor>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_ceil>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_round(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_round>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_trunc(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
unary_op<op_trunc>(params, dst);
|
||||
}
|
||||
|
||||
void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
|
||||
const float alpha_n = ggml_get_op_params_f32(dst, 1);
|
||||
const float alpha_p = ggml_get_op_params_f32(dst, 2);
@ -22,6 +22,10 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
|
|||
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_floor(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_round(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_trunc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||
|
||||
#ifdef __cplusplus
@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
|
|||
}
|
||||
}
|
||||
|
||||
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) {
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
|
||||
// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
|
||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||
for (; i + 15 < n; i += 16) {
|
||||
__m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i),
|
||||
_mm512_set1_ps(mean));
|
||||
_mm512_storeu_ps(y + i, val);
|
||||
sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val));
|
||||
}
|
||||
#elif defined(__AVX2__) && defined(__FMA__)
|
||||
for (; i + 7 < n; i += 8) {
|
||||
__m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i),
|
||||
_mm256_set1_ps(mean));
|
||||
_mm256_storeu_ps(y + i, val);
|
||||
val = _mm256_mul_ps(val,val);
|
||||
__m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
|
||||
_mm256_castps256_ps128(val));
|
||||
val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
|
||||
val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
|
||||
sum += (ggml_float)_mm_cvtss_f32(val2);
|
||||
}
|
||||
#elif defined(__SSE2__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
__m128 val = _mm_sub_ps(_mm_loadu_ps(x + i),
|
||||
_mm_set1_ps(mean));
|
||||
_mm_storeu_ps(y + i, val);
|
||||
val = _mm_mul_ps(val, val);
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
val = _mm_add_ps(val, _mm_movehl_ps(val, val));
|
||||
val = _mm_add_ss(val, _mm_movehdup_ps(val));
|
||||
#else
|
||||
__m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
|
||||
val = _mm_add_ps(val, tmp);
|
||||
tmp = _mm_movehl_ps(tmp, val);
|
||||
val = _mm_add_ss(val, tmp);
|
||||
#endif // __AVX__ || __AVX2__ || __AVX512F__
|
||||
sum += (ggml_float)_mm_cvtss_f32(val);
|
||||
}
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = vsubq_f32(vld1q_f32(x + i),
|
||||
vdupq_n_f32(mean));
|
||||
vst1q_f32(y + i, val);
|
||||
val = vmulq_f32(val, val);
|
||||
sum += (ggml_float)vaddvq_f32(val);
|
||||
}
|
||||
#elif defined(__VXE__) || defined(__VXE2__)
|
||||
for (; i + 3 < n; i += 4) {
|
||||
float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean));
|
||||
vec_xst(val, 0, y + i);
|
||||
val = vec_mul(val, val);
|
||||
sum += (ggml_float)vec_hsum_f32x4(val);
|
||||
}
|
||||
#endif
|
||||
for (; i < n; ++i) {
|
||||
float val = x[i] - mean;
|
||||
y[i] = val;
|
||||
val *= val;
|
||||
sum += (ggml_float)val;
|
||||
}
|
||||
return sum/n;
|
||||
}
|
||||
|
||||
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
||||
int i = 0;
|
||||
ggml_float sum = 0;
@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t *
|
|||
void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
|
||||
|
||||
void ggml_vec_silu_f32(const int n, float * y, const float * x);
|
||||
ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); // also centers y (y = x - mean)
|
||||
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max);
|
||||
ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max);
|
||||
|
||||
|
|
@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
|||
for (int i = 0; i < np; i += ggml_f16_step) {
|
||||
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
|
||||
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
|
||||
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
|
||||
|
||||
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
|
||||
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
|
||||
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
|
||||
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
|
||||
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
|
||||
|
|
@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
|
|||
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
|
||||
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
|
||||
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
|
||||
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
|
||||
|
||||
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
|
||||
|
|
@ -819,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f
|
|||
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); }
|
||||
inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
|
||||
for (int i = 0; i < n; ++i) {
|
||||
y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i])));
|
||||
const float v = GGML_CPU_FP16_TO_FP32(x[i]);
|
||||
y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v));
|
||||
}
|
||||
}
|
||||
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
|
||||
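The f16 fix above aligns the half-precision path with the f32 definition of ELU: identity for positive inputs, expm1 otherwise. A scalar reference:

    #include <cmath>
    #include <cstdio>

    static float elu_ref(float x) {
        return (x > 0.0f) ? x : std::expm1(x);   // the old f16 path applied expm1 to every element
    }

    int main() {
        std::printf("elu(1.5)=%f elu(-1.5)=%f\n", elu_ref(1.5f), elu_ref(-1.5f));
        return 0;
    }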
@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
|
|||
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
|
||||
|
||||
file(GLOB GGML_SOURCES_CUDA "*.cu")
|
||||
file(GLOB SRCS "template-instances/fattn-tile*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/fattn-mma*.cu")
|
||||
list(APPEND GGML_SOURCES_CUDA ${SRCS})
|
||||
file(GLOB SRCS "template-instances/mmq*.cu")
|
||||
@ -1,5 +1,81 @@
|
|||
#include "argsort.cuh"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
|
||||
const int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int row = blockIdx.y;
|
||||
|
||||
if (col < ncols && row < nrows) {
|
||||
indices[row * ncols + col] = col;
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
|
||||
sizeof(float) * 8, stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||
void * d_temp_storage = temp_storage_alloc.get();
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
0, sizeof(float) * 8, stream);
|
||||
}
|
||||
}
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
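The CUB path above uses the usual two-phase pattern: a first SortPairs call with a null workspace only queries temp_storage_bytes, then the real call sorts each row as its own segment, where segment i spans [offsets[i], offsets[i+1]). A small host-side sketch of the segment layout produced by init_offsets and init_indices (plain C++, no GPU required):

    #include <cstdio>
    #include <vector>

    int main() {
        // Row-major matrix of nrows segments with ncols keys each, as sorted by the segmented path above.
        const int ncols = 4, nrows = 3;

        // init_offsets: offsets[i] = i * ncols, so segment i covers [offsets[i], offsets[i + 1]).
        std::vector<int> offsets(nrows + 1);
        for (int i = 0; i <= nrows; ++i) {
            offsets[i] = i * ncols;
        }

        // init_indices: each row starts from the identity permutation 0..ncols-1.
        std::vector<int> indices(ncols * nrows);
        for (int row = 0; row < nrows; ++row) {
            for (int col = 0; col < ncols; ++col) {
                indices[row * ncols + col] = col;
            }
        }

        for (int i = 0; i < nrows; ++i) {
            std::printf("segment %d: [%d, %d)\n", i, offsets[i], offsets[i + 1]);
        }
        return 0;
    }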
|
||||
// Bitonic sort implementation
|
||||
template<typename T>
|
||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
|
|
@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
|
|||
return n;
|
||||
}
|
||||
|
||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
||||
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
|
|
@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
|||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||
|
||||
if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
} else {
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
}
|
||||
#else
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
#endif
|
||||
}
|
||||
@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
|||
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
|
||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||
|
||||
if (block_nums.z > 65535) {
|
||||
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||
@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
|
|||
}
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
|
||||
return GGML_CUDA_CC_IS_AMD(cc) ||
|
||||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
|
||||
}
|
||||
|
||||
// To be used for feature selection of external libraries, e.g. cuBLAS.
|
||||
|
|
@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
|
|||
}
|
||||
|
||||
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
|
||||
// Important: do not use this function if dst and src both point at registers.
|
||||
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
|
||||
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
|
||||
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
|
||||
template <int nbytes, int alignment = 0>
|
||||
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
|
||||
if constexpr (alignment != 0) {
|
||||
|
|
@ -939,13 +944,6 @@ struct ggml_cuda_graph {
|
|||
bool disable_due_to_failed_graph_capture = false;
|
||||
int number_consecutive_updates = 0;
|
||||
std::vector<ggml_graph_node_properties> ggml_graph_properties;
|
||||
bool use_cpy_indirection = false;
|
||||
std::vector<char *> cpy_dest_ptrs;
|
||||
char ** dest_ptrs_d;
|
||||
int dest_ptrs_size = 0;
|
||||
// Index to allow each cpy kernel to be aware of it's position within the graph
|
||||
// relative to other cpy nodes.
|
||||
int graph_cpynode_index = -1;
|
||||
#endif
|
||||
};
|
||||
|
||||
|
|
@ -1007,3 +1005,16 @@ struct ggml_backend_cuda_context {
|
|||
return pool(device);
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_mm_fusion_args_host {
|
||||
const ggml_tensor * x_bias = nullptr;
|
||||
const ggml_tensor * gate = nullptr;
|
||||
const ggml_tensor * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
struct ggml_cuda_mm_fusion_args_device {
|
||||
const void * x_bias = nullptr;
|
||||
const void * gate = nullptr;
|
||||
const void * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
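The two structs above mirror each other: the host side carries optional ggml_tensor descriptors for a fused bias/gate epilogue, the device side the raw pointers extracted from them. A toy standalone sketch of that host-to-device argument split (made-up types, not the ggml API):

    #include <cstdio>

    struct toy_tensor { float * data; };

    struct toy_fusion_args_host   { const toy_tensor * gate = nullptr; };
    struct toy_fusion_args_device { const void *       gate = nullptr; };

    // Extract the raw pointer the kernel will see from the host-side descriptor.
    static toy_fusion_args_device to_device_args(const toy_fusion_args_host & h) {
        toy_fusion_args_device d;
        d.gate = h.gate ? h.gate->data : nullptr;
        return d;
    }

    int main() {
        float buf[4] = {};
        toy_tensor gate = { buf };
        toy_fusion_args_host h;
        h.gate = &gate;
        toy_fusion_args_device d = to_device_args(h);
        std::printf("gate ptr set: %d\n", d.gate != nullptr);
        return 0;
    }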
@ -1,3 +1,4 @@
|
|||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
@ -8,18 +8,16 @@
|
|||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
|
||||
// then combine those indices with the corresponding byte offsets to get the total offsets
|
||||
const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
|
|
@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
|
|
@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
|
|||
}
|
||||
|
||||
template <cpy_kernel_t cpy_blck, int qk>
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
|
||||
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
const int nb12, const int nb13) {
|
||||
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
|
||||
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index]: cdst_direct;
|
||||
|
||||
const int i03 = i/(ne00 * ne01 * ne02);
|
||||
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
|
|
@ -118,67 +112,71 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
|
|||
cpy_blck(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
// Copy destination pointers to GPU to be available when pointer indirection is in use
|
||||
template<typename src_t, typename dst_t>
|
||||
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
|
||||
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
|
||||
CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
if (cuda_graph->dest_ptrs_d != nullptr) {
|
||||
CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
|
||||
}
|
||||
CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
|
||||
cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
|
||||
if (i >= ne) {
|
||||
return;
|
||||
}
|
||||
// copy destination pointers to GPU
|
||||
CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
|
||||
cuda_graph->graph_cpynode_index = 0; // reset index
|
||||
#else
|
||||
GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
|
||||
#endif
|
||||
|
||||
const src_t * x = (const src_t *) cx;
|
||||
dst_t * dst = (dst_t *) cdst;
|
||||
|
||||
dst[i] = ggml_cuda_cast<dst_t>(x[i]);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_contiguous_cuda(
|
||||
const char * cx, char * cdst, const int64_t ne,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne);
|
||||
}
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static void ggml_cpy_flt_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q8_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int num_blocks = ne / QK8_0;
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q8_0_f32_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int num_blocks = ne / QK4_0;
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_0_f32_cuda(
|
||||
|
|
@ -187,22 +185,22 @@ static void ggml_cpy_q4_0_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q4_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int num_blocks = ne / QK4_1;
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q4_1_f32_cuda(
|
||||
|
|
@ -211,22 +209,22 @@ static void ggml_cpy_q4_1_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_0_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int num_blocks = ne / QK5_0;
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_0_f32_cuda(
|
||||
|
|
@ -235,22 +233,22 @@ static void ggml_cpy_q5_0_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_q5_1_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int num_blocks = ne / QK5_1;
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_q5_1_f32_cuda(
|
||||
|
|
@ -259,25 +257,25 @@ static void ggml_cpy_q5_1_f32_cuda(
|
|||
const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12,
|
||||
const int nb10, const int nb11, const int nb12, const int nb13,
|
||||
cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
cudaStream_t stream) {
|
||||
const int num_blocks = ne;
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
const char * cx, char * cdst, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
|
||||
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int num_blocks = ne / QK4_NL;
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
|
||||
|
|
@ -311,17 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||
char * src0_ddc = (char *) src0->data;
|
||||
char * src1_ddc = (char *) src1->data;
|
||||
|
||||
char ** dest_ptrs_d = nullptr;
|
||||
int graph_cpynode_index = -1;
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
|
||||
graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
|
||||
|
|
@ -329,134 +319,94 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
|||
} else
|
||||
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
|
||||
{
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
if (contiguous_srcs) {
|
||||
ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
|
||||
} else {
|
||||
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
|
||||
if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
|
||||
ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(disable_indirection_for_this_node);
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
bool disable_indirection = true;
|
||||
ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
|
||||
}
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
// Prioritize CUDA graph compatibility over direct memory copy optimization.
|
||||
// Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
|
||||
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, half>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<half, float>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
|
||||
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
}
|
||||
ggml_cuda_cpy(ctx, src0, dst);
|
||||
}
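// Illustrative sketch (not part of the patch): the pointer returned by
// ggml_cuda_cpy_fn() identifies the kernel a CPY node would launch, and a
// nullptr means the copy is performed as a plain device memcpy. A caller such
// as the CUDA-graph compatibility check can therefore test it like this
// (helper name is an assumption):
static bool ggml_cuda_cpy_uses_kernel(const ggml_tensor * src0, ggml_tensor * src1) {
    return ggml_cuda_cpy_fn(src0, src1) != nullptr;
}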
|
||||
@ -2,10 +2,6 @@
|
|||
|
||||
#define CUDA_CPY_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
|
||||
|
||||
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
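// Usage sketch (assumption, mirroring the existing graph-capture path): the
// destination pointers of all CPY nodes are gathered on the host and uploaded
// once per graph update so the copy kernels can resolve them via indirection
// from inside the captured CUDA graph:
//   std::vector<char *> dest_ptrs;   // filled from the GGML_OP_CPY nodes
//   ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(),
//                                dest_ptrs.data(), dest_ptrs.size(),
//                                cuda_ctx->stream());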
|
||||
@ -793,8 +793,6 @@ void launch_fattn(
|
|||
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
|
||||
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
|
|
@ -878,7 +876,7 @@ void launch_fattn(
|
|||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
|
|
@ -897,6 +895,7 @@ void launch_fattn(
|
|||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
GGML_ASSERT(max_blocks_per_sm > 0);
|
||||
int parallel_blocks = max_blocks_per_sm;
|
||||
|
||||
dim3 blocks_num;
|
||||
|
|
@ -916,8 +915,7 @@ void launch_fattn(
|
|||
|
||||
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
|
||||
} else {
|
||||
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
|
||||
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
|
|
@ -946,7 +944,7 @@ void launch_fattn(
|
|||
|
||||
blocks_num.x = ntiles_x;
|
||||
blocks_num.y = parallel_blocks;
|
||||
blocks_num.z = Q->ne[2]*Q->ne[3];
|
||||
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
@ -1,756 +1,45 @@
|
|||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
#include "fattn-tile.cuh"
|
||||
#include "fattn-wmma-f16.cuh"
|
||||
|
||||
// kq_stride == number of KQ rows to process per iteration
|
||||
// kq_nbatch == number of K columns to load in parallel for KQ calculation
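// Illustrative example (assumed values, not taken from the patch): on an
// NVIDIA GPU with fast FP16, D == 128 and ncols > 16 the helpers below return
// kq_stride == 64 and kq_nbatch == 128, i.e. each main-loop iteration scores
// 64 KV positions against the Q tile and streams the full 128-wide head
// dimension in a single batch.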
|
||||
|
||||
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
if (fast_fp16_available(cc)) {
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
return -1;
|
||||
}
|
||||
GGML_UNUSED(warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
#ifdef RDNA
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 128;
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols == 32 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols == 32 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // RDNA
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
case 128:
|
||||
return ncols <= 16 ? 64 : 32;
|
||||
case 256:
|
||||
return 32;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
|
||||
#ifdef GGML_USE_HIP
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
case 256:
|
||||
return 128;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#else
|
||||
switch (D) {
|
||||
case 64:
|
||||
return 64;
|
||||
case 128:
|
||||
return 128;
|
||||
case 256:
|
||||
return ncols <= 16 ? 128 : 64;
|
||||
default:
|
||||
return -1;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
#endif // GGML_USE_HIP
|
||||
GGML_UNUSED_VARS(ncols, warp_size);
|
||||
}
|
||||
|
||||
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED_VARS(cc, ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
|
||||
return 256;
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
|
||||
#ifdef RDNA
|
||||
return 3;
|
||||
#else
|
||||
return ncols <= 16 ? 3 : 2;
|
||||
#endif // RDNA
|
||||
GGML_UNUSED(ncols);
|
||||
}
|
||||
|
||||
template<int D, int ncols, bool use_logit_softcap> // D == head size
|
||||
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
|
||||
static __global__ void flash_attn_tile(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
const int * __restrict__ KV_max,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const uint32_t n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
|
||||
const int32_t nb01, const int32_t nb02, const int32_t nb03,
|
||||
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
|
||||
const int32_t nb11, const int32_t nb12, const int64_t nb13,
|
||||
const int32_t nb21, const int32_t nb22, const int64_t nb23,
|
||||
const int32_t ne31, const int32_t ne32, const int32_t ne33,
|
||||
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
|
||||
#ifdef FLASH_ATTN_AVAILABLE
|
||||
|
||||
// Skip unused kernel variants for faster compilation:
|
||||
#ifdef GGML_USE_WMMA_FATTN
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
#endif // GGML_USE_WMMA_FATTN
|
||||
|
||||
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int warp_size = 32;
|
||||
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
|
||||
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
|
||||
static_assert(kq_stride % warp_size == 0, "kq_stride not divisible by warp_size.");
|
||||
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
|
||||
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
|
||||
|
||||
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const float * sinksf = (const float *) (sinks);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
|
||||
constexpr int cpy_ne = cpy_nb / 4;
|
||||
|
||||
constexpr int cpw = ncols/nwarps; // cols per warp
|
||||
|
||||
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
|
||||
// KQ is originally 2D but uses a Z-shaped memory pattern for larger reads/writes.
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
|
||||
|
||||
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ half2 Q_tmp[ncols][D/2];
|
||||
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#else
|
||||
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
|
||||
|
||||
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
|
||||
__shared__ float Q_tmp[ncols][D];
|
||||
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
|
||||
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
|
||||
|
||||
float KQ_max[cpw];
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
||||
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
|
||||
}
|
||||
float KQ_sum[cpw] = {0.0f};
|
||||
|
||||
// Load Q data, convert to FP16 if fast.
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
const int j = j0 + threadIdx.y*cpw;
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
float tmp_f[cpy_ne_D] = {0.0f};
|
||||
if (ic0 + j < ne01) {
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp_f[i1] *= scale;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
|
||||
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Main loop over KV cache:
|
||||
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
|
||||
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
|
||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||
|
||||
float KQ_max_new[cpw];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < cpw; ++j) {
|
||||
KQ_max_new[j] = KQ_max[j];
|
||||
}
|
||||
|
||||
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
|
||||
|
||||
// KQ = K @ Q matrix multiplication:
|
||||
#pragma unroll
|
||||
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
|
||||
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
|
||||
half2 tmp_h2[cpy_ne_kqnb/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_kqnb/2];
|
||||
#pragma unroll
|
||||
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
|
||||
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
|
||||
half2 K_k[kq_stride/warp_size][cpy_ne];
|
||||
half2 Q_k[cpw][cpy_ne];
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
|
||||
float K_k[kq_stride/warp_size][cpy_ne];
|
||||
float Q_k[cpw][cpy_ne];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
|
||||
#else
|
||||
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < cpy_ne; ++k) {
|
||||
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (k_KQ_0 + kq_nbatch < D) {
|
||||
__syncthreads(); // Sync not needed on last iteration.
|
||||
}
|
||||
}
|
||||
|
||||
// Apply logit softcap, mask, update KQ_max:
|
||||
#pragma unroll
|
||||
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
|
||||
const int i_KQ = i_KQ_0 + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
|
||||
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (use_logit_softcap) {
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
|
||||
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||
|
||||
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
half tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#else
|
||||
float tmp[kq_stride/warp_size][softmax_iter_j];
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
|
||||
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
|
||||
KQ_max[j0+j1] = KQ_max_new[j0+j1];
|
||||
|
||||
float KQ_sum_add = 0.0f;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
|
||||
KQ_sum_add += val;
|
||||
tmp[i0/warp_size][j1] = val;
|
||||
}
|
||||
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
|
||||
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
|
||||
}
|
||||
}
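// The loop above is the standard online-softmax update, shown here for
// reference (assumed formulation, matching the code):
//   m_new = max(m_old, max_i s_i)
//   l_new = l_old * exp(m_old - m_new) + sum_i exp(s_i - m_new)
//   acc   = acc   * exp(m_old - m_new)
// where s_i are the masked KQ scores, m is KQ_max, l is KQ_sum and acc is VKQ.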
|
||||
|
||||
// VKQ = V @ KQ matrix multiplication:
|
||||
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
|
||||
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
|
||||
const int k_tile = k1 + threadIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
|
||||
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
half2 tmp_h2[cpy_ne_D/2];
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
|
||||
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
|
||||
|
||||
float2 tmp_f2[cpy_ne_D/2];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
|
||||
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
|
||||
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
half2 V_k[(D/2)/warp_size];
|
||||
half2 KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
half tmp[softmax_iter_j];
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
|
||||
&tmp, KQ[j][k0 + k1]);
|
||||
#pragma unroll
|
||||
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
|
||||
KQ_k[j0+j1] = __half2half2(tmp[j1]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
|
||||
float2 V_k[(D/2)/warp_size];
|
||||
float KQ_k[cpw];
|
||||
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
|
||||
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
|
||||
|
||||
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
|
||||
&KQ_k[j0], KQ[j][k0 + k1]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
|
||||
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Attention sink: adjust running max and sum once per head
|
||||
if (sinksf && blockIdx.y == 0) {
|
||||
const float sink = sinksf[head];
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cpw; ++j0) {
|
||||
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
|
||||
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
|
||||
|
||||
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
|
||||
KQ_max[j0] = KQ_max_new_j;
|
||||
|
||||
const float val = expf(sink - KQ_max[j0]);
|
||||
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
KQ_sum[j0] += val;
|
||||
}
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
||||
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
|
||||
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
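// In effect the sink acts as one extra attention logit per row that
// contributes only to the softmax denominator (assumed intent, matching the
// code above):
//   m = max(m, sink);  l = l * exp(m_old - m) + exp(sink - m)
// The exp(sink - m) term is added by a single lane because KQ_sum is reduced
// across the warp afterwards.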
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
if (gridDim.y == 1) {
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
|
||||
}
|
||||
#else
|
||||
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < (D/2)/warp_size; ++i) {
|
||||
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
|
||||
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
// Write back results:
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
|
||||
|
||||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#ifdef FAST_FP16_AVAILABLE
|
||||
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
|
||||
float2 tmp[cpy_ne_D];
|
||||
#pragma unroll
|
||||
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
|
||||
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
|
||||
}
|
||||
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
|
||||
}
|
||||
#else
|
||||
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
|
||||
ggml_cuda_memcpy_1<cpy_ne_D*4>(
|
||||
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
|
||||
}
|
||||
#endif // FAST_FP16_AVAILABLE
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
|
||||
max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
ne00, ne01, ne02, ne03,
|
||||
nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
ne31, ne32, ne33,
|
||||
nb31, nb32, nb33);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FLASH_ATTN_AVAILABLE
|
||||
}
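// Note (assumption consistent with the kernel above): when launched with
// gridDim.y > 1 each block leaves its VKQ accumulators unnormalized and
// records the per-row (max, sum) pair in dst_meta, so a separate combine pass
// can merge the partial results with the same online-softmax rescaling before
// the final division by the softmax sum.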
|
||||
|
||||
template <int D, bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const int warp_size = 32;
|
||||
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
|
||||
#ifdef GGML_USE_HIP
|
||||
if constexpr (D <= 128) {
|
||||
if (Q->ne[1] > 32) {
|
||||
constexpr int cols_per_block = 64;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif // GGML_USE_HIP
|
||||
|
||||
if (Q->ne[1] > 16) {
|
||||
constexpr int cols_per_block = 32;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int cols_per_block = 16;
|
||||
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
|
||||
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
|
||||
launch_fattn<D, cols_per_block, 1>
|
||||
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
|
||||
}
|
||||
|
||||
template <bool use_logit_softcap>
|
||||
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
switch (Q->ne[0]) {
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
switch (K->ne[0]) {
|
||||
case 40: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
|
||||
} break;
|
||||
case 64: {
|
||||
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
|
||||
} break;
|
||||
case 80: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
|
||||
} break;
|
||||
case 96: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
|
||||
} break;
|
||||
case 112: {
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
|
||||
} break;
|
||||
case 128: {
|
||||
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
|
||||
} break;
|
||||
case 256: {
|
||||
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
|
||||
GGML_ASSERT(V->ne[0] == K->ne[0]);
|
||||
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
|
||||
} break;
|
||||
case 576: {
|
||||
GGML_ASSERT(V->ne[0] == 512);
|
||||
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("Unsupported head size");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap == 0.0f) {
|
||||
constexpr bool use_logit_softcap = false;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
} else {
|
||||
constexpr bool use_logit_softcap = true;
|
||||
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
|
|||
const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
|
||||
constexpr bool need_f16_K = false;
|
||||
constexpr bool need_f16_V = false;
|
||||
const bool need_f16_K = type_K == GGML_TYPE_F16;
|
||||
const bool need_f16_V = type_V == GGML_TYPE_F16;
|
||||
constexpr size_t nbytes_shared = 0;
|
||||
launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
|
||||
}
|
||||
|
|
@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
|
|||
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
GGML_ASSERT(K->type == type_K);
|
||||
GGML_ASSERT(V->type == type_V);
|
||||
|
||||
float logit_softcap;
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
@ -1,3 +1,5 @@
|
|||
#pragma once
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
|
||||
@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
|||
}
|
||||
}
|
||||
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
#define FATTN_VEC_CASE(D, type_K, type_V) \
|
||||
{ \
|
||||
const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
|
||||
const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
|
||||
if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
|
||||
ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
|
||||
return; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
|
||||
FATTN_VEC_CASE( 64, type_K, type_V) \
|
||||
|
|
@ -198,6 +202,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_NONE;
|
||||
#endif// FLASH_ATTN_AVAILABLE
|
||||
|
||||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
|
@ -206,37 +211,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
|
||||
// The effective batch size for the kernel can be increased by gqa_ratio.
|
||||
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
|
||||
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[device].cc;
|
||||
|
||||
// TODO: temporary until support is extended
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
|
||||
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
switch (K->ne[0]) {
|
||||
case 40:
|
||||
case 64:
|
||||
case 128:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 80:
|
||||
case 96:
|
||||
case 128:
|
||||
case 112:
|
||||
case 256:
|
||||
if (V->ne[0] != K->ne[0]) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
case 576:
|
||||
if (V->ne[0] != 512) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
|
||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
||||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
break;
|
||||
|
|
@ -251,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
#endif // GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
switch (K->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
|
@ -270,47 +271,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
|||
return BEST_FATTN_KERNEL_NONE;
|
||||
}
|
||||
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
|
||||
|
||||
// If Turing tensor cores available, use them except for some cases with batch size 1:
|
||||
if (turing_mma_available(cc)) {
|
||||
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
|
||||
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
|
||||
|
||||
// If Turing tensor cores available, use them:
|
||||
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
|
||||
if (can_use_vector_kernel) {
|
||||
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
||||
if (Q->ne[1] <= 2) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC;
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
|
||||
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
|
||||
if (!gqa_opt_applies && Q->ne[1] == 1) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
|
||||
return best;
|
||||
return BEST_FATTN_KERNEL_MMA_F16;
|
||||
}
|
||||
|
||||
// Use kernels specialized for small batch sizes if possible:
|
||||
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
|
||||
// For large batch sizes, use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc)) {
|
||||
// Use the WMMA kernel if possible:
|
||||
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
|
||||
if (can_use_vector_kernel && Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
return BEST_FATTN_KERNEL_WMMA_F16;
|
||||
}
|
||||
|
||||
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
|
||||
// If there are no tensor cores available, use the generic tile kernel:
|
||||
if (can_use_vector_kernel) {
|
||||
if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
|
||||
if (Q->ne[1] == 1) {
|
||||
if (!gqa_opt_applies) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 2) {
|
||||
return BEST_FATTN_KERNEL_VEC;
|
||||
}
|
||||
}
|
||||
}
|
||||
return BEST_FATTN_KERNEL_TILE;
|
||||
}
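// Summary of the selection order implemented above (description only, no new
// behaviour): the Turing mma kernel is preferred when tensor cores and a
// padded KV cache are available, with the vector kernel taking over for some
// batch-size-1/2 cases; the WMMA kernel is the fallback for older tensor-core
// hardware; otherwise the choice is between the vector kernel for small
// batches and the generic tile kernel.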
|
||||
@ -231,7 +231,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
|
||||
info.default_tensor_split[id] = total_vram;
|
||||
total_vram += prop.totalGlobalMem;
|
||||
info.devices[id].integrated = prop.integrated;
|
||||
info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
|
||||
info.devices[id].nsm = prop.multiProcessorCount;
|
||||
info.devices[id].smpb = prop.sharedMemPerBlock;
|
||||
info.devices[id].warp_size = prop.warpSize;
|
||||
|
|
@ -273,6 +273,15 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
|||
} else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
|
||||
turing_devices_without_mma.push_back({ id, device_name });
|
||||
}
|
||||
|
||||
// Temporary performance fix:
|
||||
// Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
|
||||
// TODO: Check for future drivers the default scheduling strategy and
|
||||
// remove this call again when cudaDeviceScheduleSpin is default.
|
||||
if (prop.major == 12 && prop.minor == 1) {
|
||||
CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
|
||||
}
|
||||
|
||||
#endif // defined(GGML_USE_HIP)
|
||||
}
|
||||
@ -1948,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
|
|||
|
||||
size_t src1_stride_size = sizeof(cuda_t);
|
||||
|
||||
dim3 block_dims(ne13, ne12);
|
||||
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
|
||||
const int threads_x = 16;
|
||||
const int threads_y = 16;
|
||||
dim3 block_dims(threads_x, threads_y);
|
||||
|
||||
dim3 grid_dims(
|
||||
(ne13 + threads_x - 1) / threads_x,
|
||||
(ne12 + threads_y - 1) / threads_y
|
||||
);
|
||||
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
|
||||
src0_ptr, src1_ptr, dst_t,
|
||||
ptrs_src.get(), ptrs_dst.get(),
|
||||
ne12, ne13,
|
||||
|
|
@ -1998,6 +2014,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
|
||||
const ggml_tensor * ffn_gate,
|
||||
const ggml_tensor * glu,
|
||||
const ggml_tensor * ffn_up_bias = nullptr,
|
||||
const ggml_tensor * ffn_gate_bias = nullptr) {
|
||||
const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
|
||||
|
||||
if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
|
||||
const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
|
||||
|
||||
GGML_ASSERT(ffn_up && ffn_gate && glu);
|
||||
|
||||
if (!is_mul_mat && !is_mul_mat_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (has_bias) {
|
||||
if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expected_bias_op == GGML_OP_ADD) {
|
||||
const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
|
||||
const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
|
||||
if (!up_has_mul || !gate_has_mul) {
|
||||
return false;
|
||||
}
|
||||
} else { // GGML_OP_ADD_ID
|
||||
if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
|
||||
return false;
|
||||
}
|
||||
if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
|
||||
!ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[1] != ffn_gate->src[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
|
||||
|
||||
if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
|
||||
ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
|
||||
|
||||
//TODO: add support for fusion for split buffers
|
||||
if (split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
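// Pattern the check above looks for (illustrative; the exact ggml helper names
// and argument order are assumptions):
//   ggml_tensor * gate = ggml_mul_mat(ctx, w_gate, x);       // GGML_OP_MUL_MAT
//   ggml_tensor * up   = ggml_mul_mat(ctx, w_up,   x);       // same activation x
//   ggml_tensor * out  = ggml_swiglu_split(ctx, gate, up);   // GGML_OP_GLU, not swapped
// Both matrix multiplications must share src[1] (and src[2] for the
// MUL_MAT_ID variant), and split buffers are excluded for now.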
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
|
||||
|
||||
bool use_mul_mat_vec_f =
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
|
||||
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
|
||||
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
return use_mul_mat_vec_f;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
|
||||
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
|
||||
src0->view_src;
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
|
||||
// fusion is not universally faster on Pascal
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
if (cc <= GGML_CUDA_CC_PASCAL) {
|
||||
return false;
|
||||
}
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||
|
||||
|
|
@ -2633,11 +2790,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}

#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
                                           bool use_cuda_graph) {

    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";

@ -2688,33 +2844,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
        }

        if (node->op == GGML_OP_CPY) {

            // Store the pointers which are updated for each token, such that these can be sent
            // to the device and accessed using indirection from CUDA graph
            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);

            // store a pointer to each copy op CUDA kernel to identify it later
            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
            if (!ptr) {
                use_cuda_graph = false;
#ifndef NDEBUG
                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
            }
        }

        if (!use_cuda_graph) {
            break;
        }
    }

    if (use_cuda_graph) {
        cuda_ctx->cuda_graph->use_cpy_indirection = true;
        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
    }

    return use_cuda_graph;
}

@ -2733,7 +2867,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p

static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
    if (node->data != graph_node_properties->node_address &&
        node->op != GGML_OP_CPY &&
        node->op != GGML_OP_VIEW) {
        return false;
    }

@ -2754,14 +2887,13 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (node->src[i] &&
            node->src[i]->data != graph_node_properties->src_address[i] &&
            node->op != GGML_OP_CPY &&
            node->op != GGML_OP_VIEW
        ) {
            return false;
        }
    }

    if (node->op == GGML_OP_SCALE &&
    if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
        return false;
    }

@ -2834,43 +2966,74 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
#endif

    // TODO: remove special case once ggml_can_fuse can handle empty nodes
    std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
    std::initializer_list<enum ggml_op> topk_moe_ops =
        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);

    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {

        if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
            return false;
        }

        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
        }
    if (ops.size() == topk_moe_ops_with_norm.size() &&
        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
        ggml_tensor * softmax = cgraph->nodes[node_idx];
        ggml_tensor * weights = cgraph->nodes[node_idx+8];
        ggml_tensor * weights = cgraph->nodes[node_idx + 9];

        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
            return true;
        }
    }

    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {

        if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
            return false;
        }

        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
        }

    if (ops.size() == topk_moe_ops.size() &&
        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
        ggml_tensor * softmax = cgraph->nodes[node_idx];
        ggml_tensor * weights = cgraph->nodes[node_idx+4];
        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
            return true;
        }
    }

    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
        ggml_tensor * weights = cgraph->nodes[node_idx + 5];

        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
            return true;
        }
    }

    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
    std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };

    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
    std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };

    if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {

        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
        const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
        const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
        const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
        const ggml_tensor * glu = cgraph->nodes[node_idx + 4];

        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
            return true;
        }
    }

    if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {

        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
        const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
        const ggml_tensor * glu = cgraph->nodes[node_idx + 2];

        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
            return true;
        }
    }

    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }
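
    // Illustrative layout (matching the index arithmetic above) for the 5-op pattern
    // { MUL_MAT, ADD, MUL_MAT, ADD, GLU }:
    //     nodes[node_idx + 0] = ffn_gate      (MUL_MAT / MUL_MAT_ID)
    //     nodes[node_idx + 1] = ffn_gate_bias (ADD / ADD_ID)
    //     nodes[node_idx + 2] = ffn_up        (MUL_MAT / MUL_MAT_ID)
    //     nodes[node_idx + 3] = ffn_up_bias   (ADD / ADD_ID)
    //     nodes[node_idx + 4] = glu           (GLU consuming the two biased projections)
    // Only the GLU output may be read by later nodes, which is presumably why node_idx + 4
    // (node_idx + 2 for the 3-op pattern) is passed as the subgraph output.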

@ -2901,7 +3064,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
    }

    // if rms norm is the B operand, then we don't handle broadcast
    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
        return false;
    }

@ -2962,21 +3125,35 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
        if (!disable_fusion) {

            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                ggml_tensor * weights = cgraph->nodes[i+8];
                ggml_tensor * selected_experts = cgraph->nodes[i+3];
                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
                i += 8;
                ggml_tensor * weights = cgraph->nodes[i + 9];
                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                ggml_tensor * clamp = cgraph->nodes[i + 7];
                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
                                      /*delayed softmax*/ false, clamp);
                i += 9;
                continue;
            }

            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                ggml_tensor * weights = cgraph->nodes[i+4];
                ggml_tensor * selected_experts = cgraph->nodes[i+3];
                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
                ggml_tensor * weights = cgraph->nodes[i + 4];
                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
                                      /*delayed softmax*/ false);
                i += 4;
                continue;
            }

            if (ggml_cuda_can_fuse(cgraph, i,
                                   ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
                ggml_tensor * weights = cgraph->nodes[i + 5];
                ggml_tensor * ids = cgraph->nodes[i + 1];

                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
                                      /*delayed_softmax*/ true);
                i += 5;
                continue;
            }
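
            // Note on the index bookkeeping above: each fused top-k MoE variant consumes a fixed
            // number of graph nodes after the current node, so i is advanced by 9, 4 or 5 and the
            // enclosing loop's own increment then lands on the first node after the fused subgraph
            // (e.g. weights at nodes[i + 9] implies a ten-node fused span).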

            if (node->op == GGML_OP_ADD) {
                int n_fuse = 0;
                ggml_op ops[8];

@ -3008,6 +3185,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                }
            }

            bool fused_mul_mat_vec = false;
            int fused_node_count = 0;

            for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
                const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;

                if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
                    ggml_tensor * glu = cgraph->nodes[i + 4];
                    ggml_tensor * gate_bias_n = glu->src[0];
                    ggml_tensor * up_bias_n = glu->src[1];

                    // we don't assume the order of {gate, up}; instead, infer it from the bias tensor
                    ggml_tensor * gate_n = nullptr;
                    ggml_tensor * up_n = nullptr;

                    if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
                        gate_n = cgraph->nodes[i];
                        up_n = cgraph->nodes[i + 2];
                    } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
                        gate_n = cgraph->nodes[i + 2];
                        up_n = cgraph->nodes[i];
                    } else {
                        continue;
                    }

                    auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
                        if (op_bias == GGML_OP_ADD) {
                            if (bias_node->src[0] == mul_node) {
                                return bias_node->src[1];
                            }
                            if (bias_node->src[1] == mul_node) {
                                return bias_node->src[0];
                            }
                            return (ggml_tensor *) nullptr;
                        }
                        GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
                        GGML_ASSERT(bias_node->src[0] == mul_node);
                        return bias_node->src[1];
                    };

                    ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
                    ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);

                    if (!up_bias_tensor || !gate_bias_tensor) {
                        continue;
                    }

                    const ggml_tensor * src0 = up_n->src[0];
                    const ggml_tensor * src1 = up_n->src[1];
                    const ggml_tensor * ids = up_n->src[2];

                    if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
                        ggml_cuda_mm_fusion_args_host fusion_data{};
                        fusion_data.gate = gate_n->src[0];
                        fusion_data.x_bias = up_bias_tensor;
                        fusion_data.gate_bias = gate_bias_tensor;
                        fusion_data.glu_op = ggml_get_glu_op(glu);

                        ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
                        fused_mul_mat_vec = true;
                        fused_node_count = 5;
                        break;
                    }

                    if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
                        ggml_cuda_mm_fusion_args_host fusion_data{};
                        fusion_data.gate = gate_n->src[0];
                        fusion_data.x_bias = up_bias_tensor;
                        fusion_data.gate_bias = gate_bias_tensor;
                        fusion_data.glu_op = ggml_get_glu_op(glu);

                        ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
                        fused_mul_mat_vec = true;
                        fused_node_count = 5;
                        break;
                    }
                } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
                    ggml_tensor * glu = cgraph->nodes[i + 2];
                    ggml_tensor * gate = glu->src[0];
                    ggml_tensor * up = glu->src[1];

                    bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
                           || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);

                    if (!ok) continue;

                    const ggml_tensor * src0 = up->src[0];
                    const ggml_tensor * src1 = up->src[1];
                    const ggml_tensor * ids = up->src[2];

                    if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
                        ggml_cuda_mm_fusion_args_host fusion_data{};
                        fusion_data.gate = gate->src[0];
                        fusion_data.glu_op = ggml_get_glu_op(glu);

                        ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
                        fused_mul_mat_vec = true;
                        fused_node_count = 3;
                        break;
                    }

                    if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
                        ggml_cuda_mm_fusion_args_host fusion_data{};
                        fusion_data.gate = gate->src[0];
                        fusion_data.glu_op = ggml_get_glu_op(glu);

                        ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
                        fused_mul_mat_vec = true;
                        fused_node_count = 3;
                        break;
                    }
                }
            }

            if (fused_mul_mat_vec) {
                i += fused_node_count - 1;
                continue;
            }
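
            // Example of the advance above: for the 5-node { MUL_MAT, ADD, MUL_MAT, ADD, GLU }
            // fusion, fused_node_count == 5, so i += 4 here plus the loop's own increment skips
            // the four nodes that were folded into the single fused kernel launch.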

            fused_mul_mat_vec = false;
            fused_node_count = 0;

            for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
                const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;

                if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
                    continue;
                }

                ggml_tensor * mm_node = cgraph->nodes[i];
                ggml_tensor * bias_node = cgraph->nodes[i + 1];

                ggml_tensor * bias_tensor = nullptr;
                if (bias_op == GGML_OP_ADD) {
                    if (bias_node->src[0] == mm_node) {
                        bias_tensor = bias_node->src[1];
                    } else if (bias_node->src[1] == mm_node) {
                        bias_tensor = bias_node->src[0];
                    } else {
                        continue;
                    }
                } else {
                    if (bias_node->src[0] != mm_node) {
                        continue;
                    }
                    bias_tensor = bias_node->src[1];
                }

                const ggml_tensor * src0 = mm_node->src[0];
                const ggml_tensor * src1 = mm_node->src[1];
                const ggml_tensor * ids = mm_node->src[2];

                if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
                    continue;
                }

                ggml_cuda_mm_fusion_args_host fusion_data{};
                fusion_data.x_bias = bias_tensor;

                if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
                    ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
                    fused_mul_mat_vec = true;
                    fused_node_count = 2;
                    break;
                }

                if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
                    ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
                    fused_mul_mat_vec = true;
                    fused_node_count = 2;
                    break;
                }
            }

            if (fused_mul_mat_vec) {
                i += fused_node_count - 1;
                continue;
            }

            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);

@ -3120,7 +3475,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
    if (use_cuda_graph) {
        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);

        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
        use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);

        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
        if (use_cuda_graph && cuda_graph_update_required) {

@ -3147,10 +3502,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
            CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
        }

        if (!use_cuda_graph) {
            cuda_ctx->cuda_graph->use_cpy_indirection = false;
        }

#else
    bool use_cuda_graph = false;
    bool cuda_graph_update_required = false;

@ -3645,12 +3996,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM:
        case GGML_OP_ACC:
            return true;
        case GGML_OP_SUM:
            return ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_ARGSORT:
            // TODO: Support arbitrary column width
#ifndef GGML_CUDA_USE_CUB
            return op->src[0]->ne[0] <= 1024;
#else
            return true;
#endif
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_GROUP_NORM:

@ -3867,7 +4222,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
        dev_ctx->device = i;
        dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);

        ggml_cuda_set_device(i);
        cudaDeviceProp prop;
        CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
        dev_ctx->description = prop.name;

@ -1,5 +1,7 @@
#include "ggml.h"
#include "mmf.cuh"
#include "mmid.cuh"


void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);

@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
    const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

    mmf_ids_data ids_info{};
    mmf_ids_data * ids_info_ptr = nullptr;
    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;

    // For MUL_MAT_ID the memory layout is different from MUL_MAT:
    const int64_t ncols_dst = ids ? ne2 : ne1;
    const int64_t nchannels_dst = ids ? ne1 : ne2;

@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
        nchannels_y = ids->ne[0];
    }

    if (ids && ncols_dst > 16) {
        const int64_t n_expert_used = ids->ne[0];
        const int64_t n_experts = ne02;
        const int64_t n_tokens = ne12;
        const int64_t ne_get_rows = n_tokens * n_expert_used;

        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);

        const int si1 = static_cast<int>(ids_s1);
        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);

        GGML_ASSERT(sis1 > 0);

        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
        CUDA_CHECK(cudaGetLastError());

        ids_info.ids_src_compact = ids_src_compact_dev.get();
        ids_info.ids_dst_compact = ids_dst_compact_dev.get();
        ids_info.expert_bounds_dev = expert_bounds_dev.get();
        ids_info.n_experts = static_cast<int>(n_experts);
        ids_info.sis1 = sis1;
        ids_info_ptr = &ids_info;
    }

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;

@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_F16: {
            const half2 * src0_d = (const half2 *) src0->data;

@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;

@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
            mul_mat_f_switch_cols_per_block(
                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
        } break;
        default:
            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
    }

    if (mul_mat_id) {
        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
            return false;
        }
        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
        } else if(src0_ne[1] > 1024 && src1_ncols > 128) {
            return false;
        }
    } else {

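// Worked example for the mul_mat_id heuristic above (illustrative): with src0_ne[1] = 4096
// rows (> 1024) the MMF path is kept for src1_ncols up to 128 and rejected beyond that,
// while a smaller matrix with src0_ne[1] = 512 rows tolerates up to 512 columns.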

@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;

#define MMF_ROWS_PER_BLOCK 32

struct mmf_ids_data {
    const int32_t * ids_src_compact = nullptr;
    const int32_t * ids_dst_compact = nullptr;
    const int32_t * expert_bounds_dev = nullptr;
    int n_experts = 0;
    int sis1 = 0;
};

void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);

@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
|
|||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
|
||||
//This kernel is for larger batch sizes of mul_mat_id
|
||||
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
|
||||
static __global__ void mul_mat_f_ids(
|
||||
const T * __restrict__ x, const float * __restrict__ y,
|
||||
const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
|
||||
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
|
||||
const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
|
||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
typedef tile<16, 8, float> tile_C;
|
||||
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
constexpr int tile_k_padded = warp_size + 4;
|
||||
constexpr int ntA = rows_per_block / tile_A::I;
|
||||
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
|
||||
|
||||
const int row0 = blockIdx.x * rows_per_block;
|
||||
|
||||
const int expert_idx = blockIdx.y;
|
||||
const int expert_start = expert_bounds[expert_idx];
|
||||
const int expert_end = expert_bounds[expert_idx + 1];
|
||||
const int ncols_expert = expert_end - expert_start;
|
||||
|
||||
const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
|
||||
const int tile_idx = blockIdx.z;
|
||||
if (tile_idx >= tiles_for_expert) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int col_base = tile_idx * cols_per_block;
|
||||
|
||||
GGML_UNUSED(channel_ratio);
|
||||
|
||||
const int channel_x = expert_idx;
|
||||
const int sample_dst = 0;
|
||||
const int sample_x = sample_dst / sample_ratio;
|
||||
const int sample_y = sample_dst;
|
||||
|
||||
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
|
||||
y += int64_t(sample_y) *stride_sample_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst;
|
||||
|
||||
const int32_t * ids_src_expert = ids_src_compact + expert_start;
|
||||
const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
char * compute_base = data_mmv;
|
||||
|
||||
//const float2 * y2 = (const float2 *) y;
|
||||
|
||||
tile_C C[ntA][ntB];
|
||||
|
||||
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
|
||||
|
||||
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
|
||||
tile_A A[ntA][warp_size / tile_A::J];
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < tile_A::I; ++i) {
|
||||
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
|
||||
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
float vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float val = 0.0f;
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
val = y[channel*stride_channel_y + token*stride_col_y + col];
|
||||
}
|
||||
}
|
||||
vals[j0] = val;
|
||||
}
|
||||
};
|
||||
|
||||
gather_tile(0, vals_buf[0]);
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
|
||||
float2 vals_buf[2][tile_B::I];
|
||||
auto gather_tile = [&](int tile_idx_local, float2 *vals) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const int j = j0 + tile_idx_local*tile_B::I;
|
||||
const int global_j = col_base + j;
|
||||
float2 tmp = make_float2(0.0f, 0.0f);
|
||||
if (j < cols_per_block && global_j < ncols_expert) {
|
||||
const int src_entry = ids_src_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
|
||||
const int token = (int) qrm.x;
|
||||
const int channel = (int) qrm.y;
|
||||
if (token < ncols_dst_total) {
|
||||
tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
|
||||
}
|
||||
}
|
||||
vals[j0] = tmp;
|
||||
}
|
||||
};
|
||||
|
||||
if (ntB > 0) {
|
||||
gather_tile(0, vals_buf[0]);
|
||||
}
|
||||
|
||||
int curr_buf = 0;
|
||||
int next_buf = 1;
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||
const float2 tmp = vals_buf[curr_buf][j0];
|
||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
gather_tile(itB + 1, vals_buf[next_buf]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
|
||||
tile_B B;
|
||||
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
|
||||
}
|
||||
}
|
||||
|
||||
if (itB + 1 < ntB) {
|
||||
curr_buf ^= 1;
|
||||
next_buf ^= 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
static_assert(std::is_same_v<T, void>, "unsupported type");
|
||||
}
|
||||
}
|
||||
|
||||
float * buf_iw = (float *) compute_base;
|
||||
constexpr int kiw = nwarps*rows_per_block + 4;
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
#pragma unroll
|
||||
for (int itB = 0; itB < ntB; ++itB) {
|
||||
#pragma unroll
|
||||
for (int itA = 0; itA < ntA; ++itA) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C::ne; ++l) {
|
||||
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
|
||||
const int j = itB*tile_C::J + tile_C::get_j(l);
|
||||
buf_iw[j*kiw + i] = C[itA][itB].x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nwarps > 1) {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
|
||||
const int j = j0 + threadIdx.y;
|
||||
|
||||
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
|
||||
return;
|
||||
}
|
||||
|
||||
float sum = 0.0f;
|
||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
||||
sum += buf_iw[j*kiw + i];
|
||||
}
|
||||
|
||||
const int global_j = col_base + j;
|
||||
if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
|
||||
const int dst_entry = ids_dst_expert[global_j];
|
||||
const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
|
||||
const int token = (int) qrm.x;
|
||||
if (token < ncols_dst_total) {
|
||||
const int slot = (int) qrm.y;
|
||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
|
||||
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
template<typename T, int cols_per_block, int nwarps>
|
||||
static inline void mul_mat_f_switch_ids(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
|
|
@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids(
|
|||
const int64_t stride_col_id, const int64_t stride_row_id,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
|
||||
if (ids) {
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
|
||||
const mmf_ids_data * ids_data) {
|
||||
const bool has_ids_data = ids_data && ids_data->ids_src_compact;
|
||||
|
||||
// Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
|
||||
// we prefer the normal mul_mat_f path with has_ids=true.
|
||||
if (has_ids_data && ncols_dst > 16) {
|
||||
const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
|
||||
if (max_tiles == 0) {
|
||||
return;
|
||||
}
|
||||
dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
|
||||
|
||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||
|
||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
sis1_fd, nch_fd);
|
||||
} else if (ids) {
|
||||
const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
|
||||
dim3 block_nums_ids = block_nums;
|
||||
block_nums_ids.y *= col_tiles;
|
||||
|
||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} else {
|
||||
|
|
@ -258,7 +532,7 @@ void mul_mat_f_cuda(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
typedef tile<16, 8, T> tile_A;
|
||||
typedef tile< 8, 8, T> tile_B;
|
||||
|
||||
|
|
@ -290,7 +564,7 @@ void mul_mat_f_cuda(
|
|||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
|
||||
const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
|
||||
|
||||
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
|
||||
const dim3 block_dims(warp_size, nwarps_best, 1);
|
||||
|
|
@ -300,49 +574,57 @@ void mul_mat_f_cuda(
|
|||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||
ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
cudaStream_t stream) {
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||
|
||||
const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
|
||||
|
||||
|
|
@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
case 1: {
|
||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 2: {
|
||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 3: {
|
||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 4: {
|
||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 5: {
|
||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 6: {
|
||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 7: {
|
||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 8: {
|
||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 9: {
|
||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 10: {
|
||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 11: {
|
||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 12: {
|
||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 13: {
|
||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 14: {
|
||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 15: {
|
||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
case 16: {
|
||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
|
|||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
|
||||
cudaStream_t stream);
|
||||
cudaStream_t stream, const mmf_ids_data * ids_data);
|
||||
|
||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
|
||||
|
|
|
|||
|
|
@ -0,0 +1,164 @@
#include "common.cuh"
#include "mmid.cuh"

// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mm_ids_helper_store {
    uint32_t data;

    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
        data = (it & 0x003FFFFF) | (iex_used << 22);
    }

    __device__ uint32_t it() const {
        return data & 0x003FFFFF;
    }

    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");

// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
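// Illustrative example (hypothetical values): with 4 experts, n_tokens = 3, n_expert_used = 2
// and per-token expert ids  token 0 -> {0, 2},  token 1 -> {1, 2},  token 2 -> {0, 3},
// the compact ordering groups the (token, slot) pairs by expert:
//     expert 0: (0, 0), (2, 0)   expert 1: (1, 0)   expert 2: (0, 1), (1, 1)   expert 3: (2, 1)
// which gives expert_bounds = {0, 2, 3, 5, 6} and ids_dst = {0, 4, 2, 1, 3, 5}
// (each entry being token*n_expert_used + slot); ids_src1 stores the matching flattened
// src1 column index for every compact entry.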
|
||||
template <int n_expert_used_template>
|
||||
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mm_ids_helper(
|
||||
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
|
||||
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
|
||||
const int expert = blockIdx.x;
|
||||
|
||||
extern __shared__ char data_mm_ids_helper[];
|
||||
mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
|
||||
|
||||
int nex_prev = 0; // Number of columns for experts with a lower index.
|
||||
int it_compact = 0; // Running index for the compact slice of this expert.
|
||||
|
||||
if constexpr (n_expert_used_template == 0) {
|
||||
// Generic implementation:
|
||||
for (int it = 0; it < n_tokens; ++it) {
|
||||
int iex_used = -1; // The index at which the expert is used, if any.
|
||||
for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
|
||||
const int expert_used = ids[it*si1 + iex];
|
||||
nex_prev += expert_used < expert;
|
||||
if (expert_used == expert) {
|
||||
iex_used = iex;
|
||||
}
|
||||
}
|
||||
|
||||
if (iex_used != -1) {
|
||||
store[it_compact] = mm_ids_helper_store(it, iex_used);
|
||||
}
|
||||
|
||||
if (warp_reduce_any<warp_size>(iex_used != -1)) {
|
||||
it_compact++;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Implementation optimized for specific numbers of experts used:
|
||||
        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
            const int it = it0 + threadIdx.x / neu_padded;

            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
                ids[it*si1 + iex] : INT_MAX;
            const int iex_used = expert_used == expert ? iex : -1;
            nex_prev += expert_used < expert;

            // Whether the threads at this token position have used the expert:
            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);

            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
            int it_compact_add_lower = 0;
#pragma unroll
            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
                    it_compact_add_lower += tmp;
                }
            }

            if (iex_used != -1) {
                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
            }

            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
        }
    }
    nex_prev = warp_reduce_sum<warp_size>(nex_prev);

    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
        const mm_ids_helper_store store_it = store[itc];
        const int it = store_it.it();
        const int iex_used = store_it.iex_used();
        ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
    }

    if (threadIdx.x != 0) {
        return;
    }

    expert_bounds[expert] = nex_prev;

    if (expert < static_cast<int>(gridDim.x) - 1) {
        return;
    }

    expert_bounds[gridDim.x] = nex_prev + it_compact;
}

template <int n_expert_used_template>
static void launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");

    const int id = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);

    const dim3 num_blocks(n_experts, 1, 1);
    const dim3 block_size(warp_size, 1, 1);
    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
    GGML_ASSERT(nbytes_shared <= smpbo);
    mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}

void ggml_cuda_launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    switch (n_expert_used) {
        case 2:
            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 4:
            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 6:
            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 8:
            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 16:
            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 32:
            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        default:
            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
    }
}
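For reference, this is the mapping the helper produces, written out as a plain host-side sketch (illustrative only, not code from this commit; the flat std::vector layout of ids and of the outputs is an assumption made for the example). For each expert e, the slots expert_bounds[e] .. expert_bounds[e+1] of the compact ordering hold one entry per (token, slot) pair routed to e; ids_src1 gives the flattened src1 column to gather from and ids_dst the flattened dst column to write to.

#include <cstdint>
#include <vector>

// Host-side reference for the compaction performed by mm_ids_helper (sketch).
static void mm_ids_helper_reference(
        const std::vector<int32_t> & ids, // n_tokens rows of stride si1, n_expert_used used slots per row
        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1,
        std::vector<int32_t> & ids_src1, std::vector<int32_t> & ids_dst, std::vector<int32_t> & expert_bounds) {
    expert_bounds.assign(n_experts + 1, 0);
    ids_src1.clear();
    ids_dst.clear();
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = (int32_t) ids_src1.size();
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] == expert) {
                    // Same formulas as the kernel epilogue above:
                    ids_src1.push_back(it*sis1 + iex % nchannels_y);
                    ids_dst .push_back(it*n_expert_used + iex);
                }
            }
        }
    }
    expert_bounds[n_experts] = (int32_t) ids_src1.size();
}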
@ -0,0 +1,5 @@
#pragma once

void ggml_cuda_launch_mm_ids_helper(
    const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
    int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
@ -1,141 +1,6 @@
#include "mmq.cuh"
#include "quantize.cuh"

#include <vector>

// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mmq_ids_helper_store {
    uint32_t data;

    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
        data = (it & 0x003FFFFF) | (iex_used << 22);
    }

    __device__ uint32_t it() const {
        return data & 0x003FFFFF;
    }

    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");

// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mmq_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
    const int expert = blockIdx.x;

    extern __shared__ char data_mmq_ids_helper[];
    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;

    int nex_prev = 0; // Number of columns for experts with a lower index.
    int it_compact = 0; // Running index for the compact slice of this expert.

    if constexpr (n_expert_used_template == 0) {
        // Generic implementation:
        for (int it = 0; it < n_tokens; ++it) {
            int iex_used = -1; // The index at which the expert is used, if any.
            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
                const int expert_used = ids[it*si1 + iex];
                nex_prev += expert_used < expert;
                if (expert_used == expert) {
                    iex_used = iex;
                }
            }

            if (iex_used != -1) {
                store[it_compact] = mmq_ids_helper_store(it, iex_used);
            }

            if (warp_reduce_any<warp_size>(iex_used != -1)) {
                it_compact++;
            }
        }
    } else {
        // Implementation optimized for specific numbers of experts used:
        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
            const int it = it0 + threadIdx.x / neu_padded;

            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
                ids[it*si1 + iex] : INT_MAX;
            const int iex_used = expert_used == expert ? iex : -1;
            nex_prev += expert_used < expert;

            // Whether the threads at this token position have used the expert:
            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);

            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
            int it_compact_add_lower = 0;
#pragma unroll
            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
                    it_compact_add_lower += tmp;
                }
            }

            if (iex_used != -1) {
                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
            }

            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
        }
    }
    nex_prev = warp_reduce_sum<warp_size>(nex_prev);

    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
        const mmq_ids_helper_store store_it = store[itc];
        const int it = store_it.it();
        const int iex_used = store_it.iex_used();
        ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
    }

    if (threadIdx.x != 0) {
        return;
    }

    expert_bounds[expert] = nex_prev;

    if (expert < static_cast<int>(gridDim.x) - 1) {
        return;
    }

    expert_bounds[gridDim.x] = nex_prev + it_compact;
}

template <int n_expert_used_template>
static void launch_mmq_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mmq_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");

    const int id = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);

    const dim3 num_blocks(n_experts, 1, 1);
    const dim3 block_size(warp_size, 1, 1);
    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
    GGML_ASSERT(nbytes_shared <= smpbo);
    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
#include "mmid.cuh"

static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
    switch (args.type_x) {
@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
        const int si1 = ids->nb[1] / ggml_element_size(ids);
        const int sis1 = nb12 / nb11;

        switch (n_expert_used) {
            case 2:
                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            case 4:
                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            case 6:
                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            case 8:
                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            case 16:
                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            case 32:
                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
            default:
                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
                break;
        }
        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
        CUDA_CHECK(cudaGetLastError());
    }
@ -1,20 +1,21 @@
#include "ggml.h"
#include "common.cuh"
#include "convert.cuh"
#include "unary.cuh"
#include "mmvf.cuh"
#include "convert.cuh"

template <typename T, typename type_acc, int ncols_dst, int block_size>
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
static __global__ void mul_mat_vec_f(
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
    const int row = blockIdx.x;
    const int channel_dst = blockIdx.y;
    const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
    const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
    const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
    const int sample_dst = blockIdx.z;
    const int sample_x = sample_dst / sample_ratio;
    const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
    const int sample_y = sample_dst;
    const int tid = threadIdx.x;

@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
    y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;

    bool use_gate = false;
    bool use_bias = false;
    bool use_gate_bias = false;
    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
    const T * gate_x = nullptr;
    const float * x_bias = nullptr;
    const float * gate_bias = nullptr;

    if constexpr (has_fusion) {
        use_gate = fusion.gate != nullptr;
        use_bias = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr;
        glu_op = fusion.glu_op;

        if (use_gate) {
            gate_x = static_cast<const T *>(fusion.gate);
        }
        if (use_bias) {
            x_bias = static_cast<const float *>(fusion.x_bias);
        }
        if (use_gate_bias) {
            gate_bias = static_cast<const float *>(fusion.gate_bias);
            use_gate_bias = use_gate;
        } else {
            use_gate_bias = false;
        }
    }

    if (use_gate) {
        gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
    }
    if constexpr (has_fusion) {
        const int channel_bias = ids ? channel_x : channel_dst;
        if (use_bias) {
            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
        if (use_gate_bias) {
            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
    }

    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;
    float * buf_iw_gate = nullptr;
    if constexpr (has_fusion) {
        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
    }

    if (block_size > warp_size) {
        if (tid < warp_size) {
            buf_iw[tid] = 0.0f;
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid] = 0.0f;
                }
            }
        }
        __syncthreads();
    }

    float sumf[ncols_dst] = {0.0f};
    float sumf_gate[ncols_dst];
    if constexpr (has_fusion) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sumf_gate[j] = 0.0f;
        }
    }

    if constexpr (std::is_same_v<T, float>) {
        const float2 * x2 = (const float2 *) x;
        const float2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const float2 *) gate_x;
            }
        }

        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = x2[col2];
            float2 tmpx_gate = make_float2(0.0f, 0.0f);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                sumf[j] += tmpx.x*tmpy.x;
                sumf[j] += tmpx.y*tmpy.y;
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
    } else if constexpr (std::is_same_v<T, half>) {
        const half2 * x2 = (const half2 *) x;
        const half2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const half2 *) gate_x;
            }
        }

        if (std::is_same_v<type_acc, float>) {
            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);

                float2 tmpx_gate = make_float2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = __half22float2(gate_x2[col2]);
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumf[j] += tmpx.x * tmpy.x;
                    sumf[j] += tmpx.y * tmpy.y;
                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                        }
                    }
                }
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};

            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const half2 tmpx = x2[col2];

                half2 tmpx_gate = make_half2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = gate_x2[col2];
                    }
                }
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);

                    if constexpr (has_fusion) {
                        if (use_gate) {
                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
                        }
                    }
                }
            }

@ -83,21 +190,86 @@ static __global__ void mul_mat_vec_f(
            for (int j = 0; j < ncols_dst; ++j) {
                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
            }

            if constexpr (has_fusion) {
                if (use_gate) {
#pragma unroll
                    for (int j = 0; j < ncols_dst; ++j) {
                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
                    }
                }
            }
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
        }
    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
        //TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const int *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const int tmpx = x2[col2];
            int tmpx_gate = 0;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
                    }
                }
            }
        }
#else
        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
        const nv_bfloat162 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const nv_bfloat162 *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const nv_bfloat162 tmpx = x2[col2];
            nv_bfloat162 tmpx_gate;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);

                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
#endif
    } else {
        static_assert(std::is_same_v<T, void>, "unsupported type");
    }
@ -106,13 +278,31 @@ static __global__ void mul_mat_vec_f(
    for (int j = 0; j < ncols_dst; ++j) {
        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);

        if constexpr (has_fusion) {
            if (use_gate) {
                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
            }
        }

        if (block_size > warp_size) {
            buf_iw[tid/warp_size] = sumf[j];
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
                }
            }
            __syncthreads();
            if (tid < warp_size) {
                sumf[j] = buf_iw[tid];
                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        sumf_gate[j] = buf_iw_gate[tid];
                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
                    }
                }
            }

            if (j < ncols_dst) {
                __syncthreads();
            }
@ -123,12 +313,70 @@ static __global__ void mul_mat_vec_f(
        return;
    }

    dst[tid*stride_col_dst + row] = sumf[tid];
    float value = sumf[tid];

    if constexpr (has_fusion) {
        if (use_bias) {
            value += x_bias[tid*stride_col_dst + row];
        }

        if (use_gate) {
            float gate_value = sumf_gate[tid];
            if (use_gate_bias) {
                gate_value += gate_bias[tid*stride_col_dst + row];
            }
            switch (glu_op) {
                case GGML_GLU_OP_SWIGLU:
                    value *= ggml_cuda_op_silu_single(gate_value);
                    break;
                case GGML_GLU_OP_GEGLU:
                    value *= ggml_cuda_op_gelu_single(gate_value);
                    break;
                case GGML_GLU_OP_SWIGLU_OAI: {
                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
                    break;
                }
                default:
                    break;
            }
        }
    }

    dst[tid*stride_col_dst + row] = value;
}
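The block above is the entire fusion epilogue: the gate matrix shares the dot products over y, and the GLU activation of the (biased) gate value scales the (biased) main value before the single store to dst. As a scalar reference of that last step (a sketch using the standard SiLU and tanh-approximated GELU formulas; the ggml_cuda_op_*_single device helpers live in unary.cuh and are not reproduced here):

#include <cmath>

// value = dot(x_row, y)    (+ optional x_bias)
// gate  = dot(gate_row, y) (+ optional gate_bias)
static float fused_glu_reference(float value, float gate, bool is_swiglu) {
    if (is_swiglu) {
        const float silu = gate / (1.0f + std::exp(-gate)); // SiLU(gate)
        return value * silu;                                // GGML_GLU_OP_SWIGLU
    }
    const float gelu = 0.5f*gate*(1.0f + std::tanh(0.79788456f*(gate + 0.044715f*gate*gate*gate)));
    return value * gelu;                                    // GGML_GLU_OP_GEGLU (tanh approximation)
}

The SWIGLU_OAI variant is handled by its own helper and is omitted from this sketch.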
template<typename T, typename type_acc, int ncols_dst, int block_size>
static void mul_mat_vec_f_switch_fusion(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);

}

template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@ -140,8 +388,8 @@ static void launch_mul_mat_vec_f_cuda(
    GGML_ASSERT(stride_col_y % 2 == 0);
    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
    GGML_ASSERT( nsamples_dst % nsamples_x == 0);
    const int64_t channel_ratio = nchannels_dst / nchannels_x;
    const int64_t sample_ratio = nsamples_dst / nsamples_x;
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
    const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);

    const int device = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
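The channel and sample ratios are now divided with fastdiv/init_fastdiv_values from common.cuh (not shown in this diff), i.e. a division by a runtime constant is replaced with a precomputed multiply-and-shift. A rough sketch of that general technique, assuming a Granlund–Montgomery style magic number (illustrative; not the exact common.cuh implementation):

#include <cstdint>

struct fastdiv_vals { uint32_t mp; uint32_t L; };

// Precompute mp, L for divisor d (1 <= d < 2^31): with L = ceil(log2(d)) and
// mp = floor(2^32*(2^L - d)/d) + 1, n/d == (umulhi(n, mp) + n) >> L for n < 2^31.
static fastdiv_vals init_fastdiv_values_sketch(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t(1) << L) < d) {
        L++;
    }
    const uint32_t mp = (uint32_t) (((uint64_t(1) << 32)*((uint64_t(1) << L) - d))/d + 1);
    return { mp, L };
}

static uint32_t fastdiv_sketch(uint32_t n, fastdiv_vals v) {
    const uint32_t hi = (uint32_t) (((uint64_t) n*v.mp) >> 32); // __umulhi(n, v.mp) on the device
    return (hi + n) >> v.L;
}

The kernels receive the precomputed values packed in a uint3, which is why the channel_ratio and sample_ratio parameters changed type above.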
@ -160,57 +408,59 @@ static void launch_mul_mat_vec_f_cuda(
        }
    }

    const int nbytes_shared = warp_size*sizeof(float);
    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case 32: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 64: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 96: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 128: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 160: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 192: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 224: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 256: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
@ -220,7 +470,7 @@ static void launch_mul_mat_vec_f_cuda(

template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@ -230,49 +480,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
    switch (ncols_dst) {
        case 1:
            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 2:
            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 3:
            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 4:
            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 5:
            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 6:
            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 7:
            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 8:
            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
@ -284,29 +534,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(

template<typename T>
static void mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, float * dst,
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        enum ggml_prec prec, cudaStream_t stream) {

    if constexpr(std::is_same_v<T, half>) {
        if (prec == GGML_PREC_DEFAULT) {
            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            return;
        }
    }
    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
        nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
        nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}

void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
        const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -332,6 +584,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
    float * dst_d = (float *) dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};

    if (fusion) {
        GGML_ASSERT( !ids || dst->ne[2] == 1);
        GGML_ASSERT( ids || dst->ne[1] == 1);
        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1 = dst->nb[1] / ts_dst;
@ -354,19 +630,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
@ -393,7 +669,6 @@ void ggml_cuda_op_mul_mat_vec_f(
    const int cc = ggml_cuda_info().devices[id].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;


    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row = ne00;
    const int64_t stride_col_y = ne10;
@ -410,22 +685,23 @@ void ggml_cuda_op_mul_mat_vec_f(
    const int64_t stride_sample_y = 0;
    const int64_t stride_sample_dst = 0;

    ggml_cuda_mm_fusion_args_device empty{};
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
@ -1,6 +1,7 @@
#include "common.cuh"

void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
    const ggml_cuda_mm_fusion_args_host * fusion = nullptr);

void ggml_cuda_op_mul_mat_vec_f(
    ggml_backend_cuda_context & ctx,
@ -1,5 +1,6 @@
#include "mmvq.cuh"
#include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh"

#include <cstdint>
@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
    return MMVQ_PARAMETERS_GENERIC;
}

static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC) {
        switch (ncols_dst) {
            case 1:
@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
    return 1;
}

template <ggml_type type, int ncols_dst>
// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q(
    const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
    const uint32_t sample_y = sample_dst;

    bool use_gate = false;
    bool use_bias = false;
    bool use_gate_bias = false;
    const void * vgate = nullptr;
    const float * x_bias = nullptr;
    const float * gate_bias = nullptr;
    ggml_glu_op active_glu;

    if constexpr (has_fusion) {
        use_gate = fusion.gate != nullptr;
        use_bias = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
        vgate = fusion.gate;
        x_bias = (const float *) fusion.x_bias;
        gate_bias = (const float *) fusion.gate_bias;
        active_glu = fusion.glu_op;
    }

    const uint32_t channel_bias = ids ? channel_x : channel_dst;

    if constexpr (has_fusion) {
        if (use_bias) {
            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
        }
        if (use_gate_bias) {
            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
        }
    }

    // partial sum for each thread
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};

    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q(
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += vec_dot_q_cuda(
                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                    }
                }
            }
        }
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    if constexpr (!has_fusion) {
        (void) tmp_shared_gate;
    } else if (!use_gate) {
        (void) tmp_shared_gate;
    }

    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
                    }
                }
            }
        }
    }
@ -216,12 +265,49 @@ static __global__ void mul_mat_vec_q(
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
                    }
                }
            }
            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
                }
            }
        }

        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
            dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
            float result = tmp[j][threadIdx.x];
            if constexpr (has_fusion) {
                if (use_bias) {
                    result += x_bias[j*stride_col_dst + threadIdx.x];
                }
                if (use_gate) {
                    float gate_value = tmp_gate[j][threadIdx.x];
                    if (use_gate_bias) {
                        gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
                    }
                    switch (active_glu) {
                        case GGML_GLU_OP_SWIGLU:
                            result *= ggml_cuda_op_silu_single(gate_value);
                            break;
                        case GGML_GLU_OP_GEGLU:
                            result *= ggml_cuda_op_gelu_single(gate_value);
                            break;
                        case GGML_GLU_OP_SWIGLU_OAI: {
                            result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
                            break;
                        }
                        default:
                            result = result * gate_value;
                            break;
                    }
                }
            }
            dst[j*stride_col_dst + threadIdx.x] = result;
        }
    }
}
@ -235,9 +321,37 @@ static std::pair<dim3, dim3> calc_launch_params(
    return {block_nums, block_dims};
}

template<ggml_type type, int c_ncols_dst>
static void mul_mat_vec_q_switch_fusion(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    if constexpr (c_ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}

template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
        const void * vx, const void * vy, const int32_t * ids, float * dst,
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int ncols_x, const int nrows_x, const int ncols_dst,
        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@ -256,80 +370,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    GGML_ASSERT(!ids || ncols_dst == 1);
    switch (ncols_dst) {
        case 1: {
            constexpr int c_ncols_dst = 1;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 2: {
            constexpr int c_ncols_dst = 2;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 3: {
            constexpr int c_ncols_dst = 3;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 4: {
            constexpr int c_ncols_dst = 4;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 5: {
            constexpr int c_ncols_dst = 5;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 6: {
            constexpr int c_ncols_dst = 6;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 7: {
            constexpr int c_ncols_dst = 7;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 8: {
            constexpr int c_ncols_dst = 8;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(has_fusion);
|
||||
}
|
||||
static void mul_mat_vec_q_switch_type(
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
|
|
@ -339,143 +456,123 @@ static void mul_mat_vec_q_switch_type(
|
|||
switch (type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -484,7 +581,8 @@ static void mul_mat_vec_q_switch_type(
|
|||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
||||
|
|
@ -508,6 +606,31 @@ void ggml_cuda_mul_mat_vec_q(
|
|||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
|
||||
if (fusion) {
|
||||
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||
|
||||
if (fusion->x_bias) {
|
||||
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.x_bias = fusion->x_bias->data;
|
||||
}
|
||||
if (fusion->gate) {
|
||||
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||
fusion_local.gate = fusion->gate->data;
|
||||
}
|
||||
if (fusion->gate_bias) {
|
||||
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||
}
|
||||
fusion_local.glu_op = fusion->glu_op;
|
||||
}
|
||||
|
||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||
const size_t size_data = ggml_nbytes(src0);
|
||||
|
|
@ -549,10 +672,10 @@ void ggml_cuda_mul_mat_vec_q(
|
|||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
|
||||
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
|
|
@ -578,8 +701,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
|||
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
||||
const int stride_col_y = src1_padded_row_size / QK8_1;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
||||
|
||||
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
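The hunks above thread an optional fusion descriptor through the MMVQ entry point: host-side ggml_tensor pointers are validated and their data pointers copied into ggml_cuda_mm_fusion_args_device before dispatch. A minimal sketch of how a caller might fill the host-side descriptor follows; the tensor names and surrounding context are hypothetical, and only the field names that appear in the hunks above are assumed.

// Hedged sketch (not part of this change): populating the host-side fusion
// descriptor for a fused up/gate mat-vec. `ffn_up_bias`, `ffn_gate` and
// `ffn_gate_bias` are illustrative tensor names; `glu_op` is assumed to take a
// ggml GLU op enum value such as GGML_GLU_OP_SWIGLU.
ggml_cuda_mm_fusion_args_host fusion{};
fusion.x_bias    = ffn_up_bias;    // F32, ne[0] must match dst->ne[0]
fusion.gate      = ffn_gate;       // same type and strides as src0
fusion.gate_bias = ffn_gate_bias;  // F32, ne[0] must match dst->ne[0]
fusion.glu_op    = GGML_GLU_OP_SWIGLU;

// Passing nullptr (the default) keeps the old single-matmul behaviour.
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, /*ids=*/nullptr, dst, &fusion);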
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(112, 112);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(128, 128);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(256, 256);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(40, 40);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(576, 512);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(64, 64);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(80, 80);

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE(96, 96);
|
||||
|
|
@ -3,8 +3,17 @@
from glob import glob
import os

HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]

TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]

SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-tile.cuh"

DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
"""

SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.

#include "../fattn-vec.cuh"
@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
for filename in glob("*.cu"):
    os.remove(filename)

for head_size_kq in HEAD_SIZES_KQ:
    head_size_v = head_size_kq if head_size_kq != 576 else 512
    with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
        f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))

for type_k in TYPES_KV:
    for type_v in TYPES_KV:
        with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
        with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
            f.write(SOURCE_FATTN_MMA_START)

            for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
            for head_size_kq in HEAD_SIZES_KQ:
                if head_size_kq == 40:
                    continue
                if head_size_kq != 576 and ncols2 == 16:
                    continue
                if head_size_kq == 576 and ncols2 != 16:
|
||||
|
|
|
|||
|
|
@ -2,23 +2,70 @@
|
|||
#include "ggml.h"
|
||||
#include "topk-moe.cuh"
|
||||
|
||||
#include <cmath>
|
||||
#include <initializer_list>
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
    float max_val = -INFINITY;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            max_val = max(max_val, vals[i]);
        }
    }

    max_val = warp_reduce_max(max_val);

    float sum = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            const float val = expf(vals[i] - max_val);
            vals[i] = val;
            sum += val;
        } else {
            vals[i] = 0.f;
        }
    }

    sum = warp_reduce_sum(sum);

    const float inv_sum = 1.0f / sum;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            vals[i] *= inv_sum;
        }
    }
}
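Each lane of the warp owns the strided subset idx = lane + i * WARP_SIZE of the expert scores, and the use_limit flag masks out entries at or past `limit` while still zeroing them. A plain host-side reference with the same contract can be handy when unit-testing the kernel; the sketch below is illustrative and not part of the change.

// Hedged scalar reference for softmax_warp_inplace: numerically stable softmax
// over vals[0..limit), zeros for vals[limit..n). Host-only, for testing.
#include <algorithm>
#include <cmath>
#include <vector>

static void softmax_ref_inplace(std::vector<float> & vals, const int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; ++i) {
        max_val = std::max(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < limit; ++i) {
        vals[i] = std::exp(vals[i] - max_val);
        sum += vals[i];
    }
    for (int i = 0; i < limit; ++i) {
        vals[i] /= sum;
    }
    for (int i = limit; i < (int) vals.size(); ++i) {
        vals[i] = 0.0f;
    }
}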
|
||||
|
||||
/*
|
||||
This kernel does the following:
|
||||
1. softmax over the logits per token [n_experts, n_tokens]
|
||||
1. optionally softmax over the logits per token [n_experts, n_tokens]
|
||||
2. argmax reduce over the top-k (n_experts_used) logits
|
||||
3. write weights + ids to global memory
|
||||
4. optionally normalize the weights
|
||||
4. optionally normalize the weights or apply softmax over the selected logits
|
||||
|
||||
It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
|
||||
*/
|
||||
template <int n_experts, bool with_norm>
|
||||
template <int n_experts, bool with_norm, bool delayed_softmax = false>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
const int row = blockIdx.x * blockDim.y + threadIdx.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
|
|
@ -30,51 +77,30 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
|
||||
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
|
||||
float logits_r[experts_per_thread];
|
||||
float wt[experts_per_thread];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
|
||||
const int expert = i + threadIdx.x;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
|
||||
}
|
||||
|
||||
float max_val = logits_r[0];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
max_val = max(val, max_val);
|
||||
if constexpr (!delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
}
|
||||
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float tmp = 0.f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
wt[i] = expf(val - max_val);
|
||||
tmp += wt[i];
|
||||
}
|
||||
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = wt[i] * inv_sum;
|
||||
}
|
||||
|
||||
//at this point, each thread holds a portion of softmax,
|
||||
//we do the argmax reduce over n_expert_used, each time marking
|
||||
//at this point, each thread holds either a portion of the softmax distribution
|
||||
//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
|
||||
//the expert weight as -inf to exclude from the next iteration
|
||||
|
||||
float wt_sum = 0.f;
|
||||
|
||||
extern __shared__ float data_topk_shared[];
|
||||
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
|
||||
float output_weights[experts_per_thread];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
|
|
@ -99,11 +125,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
}
|
||||
}
|
||||
|
||||
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
output_weights[k / WARP_SIZE] = max_val;
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
wt_shared_ptr[k] = max_val;
|
||||
ids[k] = max_expert;
|
||||
ids[k] = max_expert;
|
||||
if constexpr (with_norm) {
|
||||
wt_sum += max_val;
|
||||
}
|
||||
|
|
@ -112,73 +141,86 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
|
||||
if constexpr (with_norm) {
|
||||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
wt_sum = max(wt_sum, clamp_val);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
||||
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
||||
weights[i] = wt_shared_ptr[i];
|
||||
if constexpr (delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = i * WARP_SIZE + threadIdx.x;
|
||||
if (idx < n_expert_used) {
|
||||
weights[idx] = output_weights[i];
|
||||
}
|
||||
}
|
||||
|
||||
if (!with_norm) {
|
||||
GGML_UNUSED(clamp_val);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm>
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used) {
|
||||
const int n_expert_used,
|
||||
const float clamp_val) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
|
|
@ -190,7 +232,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm) {
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax,
|
||||
ggml_tensor * clamp) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
|
@ -198,7 +242,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
const int n_experts = logits->ne[0];
|
||||
const int n_rows = logits->ne[1];
|
||||
|
||||
const float * logits_d = (const float *) logits->src[0]->data;
|
||||
const float * logits_d = (const float *) logits->data;
|
||||
float * weights_d = (float *) weights->data;
|
||||
int32_t * ids_d = (int32_t *) ids->data;
|
||||
|
||||
|
|
@ -206,14 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
float clamp_val = -INFINITY;
|
||||
if (with_norm) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
if (clamp) {
|
||||
clamp_val = ggml_get_op_params_f32(clamp, 0);
|
||||
}
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
GGML_ASSERT(clamp == nullptr);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
|
||||
clamp_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
|
|
@ -239,19 +294,43 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
|||
return false;
|
||||
}
|
||||
|
||||
if (clamp) {
|
||||
if (clamp->op != GGML_OP_CLAMP) {
|
||||
return false;
|
||||
}
|
||||
float max_val = ggml_get_op_params_f32(clamp, 1);
|
||||
|
||||
if (max_val != INFINITY) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return true;
|
||||
}
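The clamp check above only admits a pure lower-bound clamp: op_params[0] holds the minimum that later guards the normalizing division, and op_params[1] must be +INF or the fusion is rejected. A sketch of the qualifying graph node, using the public ggml_clamp API, is shown below; `ctx0`, `sums` and `min_val` are illustrative names and not part of the change.

// Hedged sketch of a clamp node this check accepts (lower bound only).
ggml_tensor * clamped = ggml_clamp(ctx0, sums, min_val, INFINITY);
// ggml_clamp(ctx0, sums, min_val, 10.0f) would store a finite max in
// op_params[1] and ggml_cuda_should_use_topk_moe() would return false.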
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
|
||||
GGML_OP_RESHAPE };
|
||||
|
||||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
GGML_ASSERT(!norm || !delayed_softmax);
|
||||
|
||||
if (delayed_softmax) {
|
||||
return delayed_softmax_ops;
|
||||
}
|
||||
|
||||
if (norm) {
|
||||
return norm_ops;
|
||||
}
|
||||
|
||||
return no_norm_ops;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,11 @@
|
|||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * top_k,
|
||||
const bool with_norm);
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false,
|
||||
ggml_tensor * weight_clamp = nullptr);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
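The delayed_softmax parameter only changes where the softmax sits relative to the top-k selection: either softmax over all logits followed by top-k (the default), or top-k over the raw logits followed by softmax over just the selected scores. The host-side reference below is a sketch to pin down that ordering and is not part of the change; with_norm (renormalizing the kept weights, with the clamp as a lower bound on their sum) is left out for brevity.

// Hedged reference for the two routing orders handled by the fused kernel.
// Assumes 0 < k <= logits.size().
#include <algorithm>
#include <cmath>
#include <numeric>
#include <utility>
#include <vector>

static std::vector<std::pair<int, float>> route_ref(const std::vector<float> & logits, const int k, const bool delayed_softmax) {
    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + k, order.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; }); // top-k by logit

    std::vector<std::pair<int, float>> out(k);
    if (delayed_softmax) {
        // softmax over the k selected logits only
        const float max_val = logits[order[0]];
        float sum = 0.0f;
        for (int i = 0; i < k; ++i) {
            out[i] = { order[i], std::exp(logits[order[i]] - max_val) };
            sum += out[i].second;
        }
        for (auto & p : out) { p.second /= sum; }
    } else {
        // softmax over all logits, then keep the k largest probabilities
        const float max_val = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (const float v : logits) { sum += std::exp(v - max_val); }
        for (int i = 0; i < k; ++i) {
            out[i] = { order[i], std::exp(logits[order[i]] - max_val) / sum };
        }
    }
    return out;
}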
|
||||
|
|
|
|||
|
|
@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu(float x) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
return ggml_cuda_op_gelu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||
|
|
@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
|
|||
}
|
||||
|
||||
static __device__ __forceinline__ float op_silu(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
return ggml_cuda_op_silu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_tanh(float x) {
|
||||
|
|
@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
|
|||
|
||||
float xi = x[j0];
|
||||
float gi = g[j1];
|
||||
xi = fminf(xi, limit);
|
||||
gi = fmaxf(fminf(gi, limit), -limit);
|
||||
|
||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||
out_glu = out_glu * (1.0f + gi);
|
||||
|
||||
dst[i] = out_glu;
|
||||
dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_NEG_BLOCK_SIZE 256
|
||||
|
|
@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
    return x / (1.0f + expf(-x));
}

__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
    const float GELU_COEF_A = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

    return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}

__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
    x = fminf(x, limit);
    g = fmaxf(fminf(g, limit), -limit);

    float out_glu = x / (1.0f + expf(-x * alpha));
    out_glu = out_glu * (1.0f + g);
    return out_glu;
}
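Moving the scalar activation math into these inline device helpers lets other kernels (such as the fused MMVQ path above) reuse exactly the same formulas. As a rough illustration only, a standalone elementwise kernel built on the SwiGLU-OAI helper could look like the sketch below; the kernel name and launch are hypothetical.

// Hedged sketch: elementwise SwiGLU-OAI using the shared helper above.
__global__ void swiglu_oai_elementwise(const float * x, const float * g, float * dst,
                                       const int n, const float alpha, const float limit) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    dst[i] = ggml_cuda_op_swiglu_oai_single(x[i], g[i], alpha, limit);
}
// Example launch: swiglu_oai_elementwise<<<(n + 255) / 256, 256, 0, stream>>>(x, g, dst, n, 1.702f, 7.0f);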
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||
|
||||
set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(htp_iface PUBLIC
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${HEXAGON_SDK_ROOT}/utils/examples
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/htp
|
||||
${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
build_idl(htp/htp_iface.idl htp_iface)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES Android)
|
||||
target_link_options(htp_iface PUBLIC -llog -ldl)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES Windows)
|
||||
target_precompile_headers(htp_iface PUBLIC <sal.h>)
|
||||
else()
|
||||
target_link_options(htp_iface PUBLIC -ldl)
|
||||
endif()
|
||||
|
||||
link_custom_library(htp_iface cdsprpc)
|
||||
link_custom_library(htp_iface rpcmem)
|
||||
|
||||
set(TARGET_NAME ggml-hexagon)
|
||||
ggml_add_backend_library(${TARGET_NAME}
|
||||
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# Build HTP bits
|
||||
set(HTP_CMAKE_ARGS
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
|
||||
|
||||
ExternalProject_Add(htp-v73
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||
|
||||
ExternalProject_Add(htp-v75
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
|
||||
|
||||
ExternalProject_Add(htp-v79
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
|
||||
|
||||
ExternalProject_Add(htp-v81
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
|
||||
|
||||
# Install Hexagon skels required at runtime
|
||||
install(FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
|
||||
TYPE LIB)
|
||||
(diff omitted: file too large to display)
|
|
@ -0,0 +1,448 @@
|
|||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
#include "ggml-hexagon.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <domain.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
if (compute_only) {
|
||||
return is_CDSP(domain_id);
|
||||
}
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
int ss_info = 0;
|
||||
if (domain_type != NULL) {
|
||||
if (strcmp(domain_type, "LPASS") == 0) {
|
||||
ss_info = FASTRPC_LPASS;
|
||||
} else if (strcmp(domain_type, "HPASS") == 0) {
|
||||
ss_info = FASTRPC_HPASS;
|
||||
} else {
|
||||
ss_info = FASTRPC_NSP;
|
||||
}
|
||||
}
|
||||
system_req_payload req = { 0 };
|
||||
req.id = FASTRPC_GET_DOMAINS;
|
||||
req.sys.domains = NULL;
|
||||
fastrpc_domain * domain = NULL;
|
||||
if (ss_info != 0) {
|
||||
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
|
||||
} else {
|
||||
req.sys.flags = 0;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
#endif
|
||||
if (remote_system_request) {
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
// Allocate memory for domain-info array
|
||||
req.sys.max_domains = req.sys.num_domains;
|
||||
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
|
||||
nErr = AEE_ENOMEMORY;
|
||||
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
|
||||
for (int i = 0; i < req.sys.num_domains; i++) {
|
||||
// Verify that only requested type domains were returned
|
||||
domain = &req.sys.domains[i];
|
||||
if (domain->type != ss_info && domain_type != NULL) {
|
||||
nErr = -1;
|
||||
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
|
||||
goto bail;
|
||||
}
|
||||
}
|
||||
*domains_info = req.sys.domains;
|
||||
*num_domains = req.sys.num_domains;
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
}
|
||||
bail:
|
||||
if (nErr && req.sys.domains) {
|
||||
free(req.sys.domains);
|
||||
}
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
|
||||
int err = 0;
|
||||
remote_rpc_effective_domain_id_t sess = { 0 };
|
||||
|
||||
sess.domain_name = domain_name;
|
||||
sess.domain_name_len = strlen(domain_name);
|
||||
sess.session_id = session_id;
|
||||
|
||||
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
|
||||
if (err) {
|
||||
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
|
||||
session_id);
|
||||
return err;
|
||||
}
|
||||
|
||||
*effec_domain_id = sess.effective_domain_id;
|
||||
return err;
|
||||
}
|
||||
|
||||
int get_dsp_support(int * domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
|
||||
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
if (dsp_capability_domain.capability == 0) {
|
||||
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
|
||||
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
|
||||
dsp_capability_domain.capability = 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if (dsp_capability_domain.capability) {
|
||||
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
|
||||
}
|
||||
}
|
||||
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr != VTCM_PAGE && attr != VTCM_COUNT) {
    nErr = AEE_EBADPARM;
    GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
    goto bail;
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for VTCM information
|
||||
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_vtcm_dsp;
|
||||
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_vtcm_dsp.attribute_ID = attr;
|
||||
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_vtcm_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
if (nErr) {
|
||||
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
|
||||
return false;
|
||||
}
|
||||
if (dsp_capability_domain.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool get_unsignedpd_support(void) {
|
||||
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
|
||||
}
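These capability probes are meant to run before a session is opened. A short usage sketch, using only helpers defined in this file, follows; the control flow is illustrative and not part of the change.

// Hedged usage sketch: pick a DSP domain and check whether an unsigned PD is available.
int domain = -1;
if (get_dsp_support(&domain) == AEE_SUCCESS) {
    const bool unsigned_pd = is_unsignedpd_supported(domain);
    GGML_LOG_INFO("ggml-hex: domain %d, unsigned pd supported: %d\n", domain, (int) unsigned_pd);
}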
|
||||
|
||||
bool is_async_fastrpc_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
|
||||
* Async fastrpc is supported only on CDSP
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_async_support;
|
||||
dsp_capability_async_support.domain = (uint32_t) domain;
|
||||
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
|
||||
dsp_capability_async_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_async_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_status_notification_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
|
||||
if (remote_handle_control) {
|
||||
/*
|
||||
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
* DSP User PD status notification Support
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_status_notification_support;
|
||||
dsp_capability_status_notification_support.domain = (uint32_t) domain;
|
||||
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
|
||||
dsp_capability_status_notification_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_status_notification_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HMX SUPPORT information
|
||||
* HMX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hmx_dsp;
|
||||
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hmx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hmx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HVX SUPPORT information
|
||||
* HVX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hvx_dsp;
|
||||
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hvx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hvx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
#ifndef HTP_UTILS_H
|
||||
#define HTP_UTILS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <inttypes.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WINDOWS
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WINDOWS
|
||||
# include <sysinfoapi.h>
|
||||
# include <windows.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
/* Including this file for custom implementation of getopt function. */
|
||||
# include "getopt_custom.h"
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WINDOWS)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WINDOWS)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
#endif
|
||||
|
||||
#if !defined(_WINDOWS)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query DSP support.
|
||||
*
|
||||
* @param[out] domain pointer to supported domain.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_dsp_support(int * domain);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query VTCM information.
|
||||
*
|
||||
* @param[in] domain value of the domain to be queried.
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to be queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
|
||||
*
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported or the capability query failed.
|
||||
*/
|
||||
|
||||
bool get_unsignedpd_support(void);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support.
|
||||
*
|
||||
* @param[in] domain value of the domain to be queried.
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported or the capability query failed.
|
||||
*/
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_valid_domain_id API: query whether a domain id is valid.
|
||||
*
|
||||
* @param[in] domain value of the domain to be queried.
|
||||
* @param[in] compute_only when enabled, the domain is compared only against the CDSP domains supported by the target.
|
||||
* @return true if value of domain is valid.
|
||||
* false if value of domain is not valid.
|
||||
*/
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
|
||||
domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_domains_info API: get information for all the domains available on the device
|
||||
*
|
||||
* @param[in] domain_type pointer to domain type
|
||||
* @param[out] num_domains pointer to the number of discovered domains.
|
||||
* @param[out] domains_info pointer used to return the discovered domains information.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
* It is the user's responsibility to free the memory pointed to by domains_info before closing the application.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
|
||||
|
||||
/**
|
||||
* get_effective_domain_id API: get effective domain id for given session id
|
||||
*
|
||||
* @param[in] domain_name pointer to domain name
|
||||
* @param[in] session_id
|
||||
* @param[out] effec_domain_id pointer used to return the obtained effective domain id.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
|
||||
|
||||
/**
|
||||
* is_async_fastrpc_supported API: query whether a domain id supports async fastrpc
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating support of Async FastRPC
|
||||
*
|
||||
*/
|
||||
|
||||
bool is_async_fastrpc_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating status notification support information
|
||||
*
|
||||
*/
|
||||
bool is_status_notification_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to be queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @param[out] arch architecture version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
/**
|
||||
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to be queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HTP_UTILS_H
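/*
 * Usage sketch (not part of this diff): a host-side probe combining the wrappers
 * declared above before offloading work to the CDSP. CDSP_DOMAIN_ID, ARCH_VER and
 * HMX_SUPPORT_SPATIAL appear in this backend; HVX_SUPPORT_128B is assumed here from
 * the FastRPC remote.h conventions and is not defined in this diff.
 */
#include <stdio.h>

static void probe_cdsp_caps(void) {
    int      arch = 0;
    uint32_t hvx  = 0;
    uint32_t hmx  = 0;

    if (get_hex_arch_ver(CDSP_DOMAIN_ID, &arch) == 0) {
        printf("hexagon arch: v%d\n", arch);
    }
    if (get_hvx_support_info(CDSP_DOMAIN_ID, &hvx, HVX_SUPPORT_128B) == 0) {
        printf("hvx 128B units: %u\n", hvx);
    }
    if (get_hmx_support_info(CDSP_DOMAIN_ID, &hmx, HMX_SUPPORT_SPATIAL) == 0) {
        printf("hmx spatial: %u\n", hmx);
    }
    printf("async fastrpc: %d, status notifications: %d\n",
           is_async_fastrpc_supported(CDSP_DOMAIN_ID),
           is_status_notification_supported(CDSP_DOMAIN_ID));
}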
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
cmake_minimum_required(VERSION 3.22.2)
|
||||
project(ggml-htp C CXX ASM)
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
|
||||
include_directories(
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
set(HTP_LIB ggml-htp-${DSP_VERSION})
|
||||
|
||||
add_library(${HTP_LIB} SHARED
|
||||
main.c
|
||||
htp_iface_skel.c
|
||||
worker-pool.c
|
||||
htp-dma.c
|
||||
hvx-sigmoid.c
|
||||
hvx-inverse.c
|
||||
hvx-exp.c
|
||||
hvx-utils.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
install(TARGETS ${HTP_LIB})
|
||||
|
|
@ -0,0 +1,448 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_act_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
#define htp_act_preamble2 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
}
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
const int32_t swapped = op_params[1];
|
||||
|
||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (!src1_valid) {
|
||||
src0 += swapped ? nc : 0;
|
||||
src1 += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
|
||||
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
|
||||
(uint8_t *) dst, nc);
|
||||
} else {
|
||||
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
|
||||
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
|
||||
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
|
||||
|
||||
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
|
||||
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
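/*
 * Scalar reference for the row loop above (a sketch for clarity only, not built for
 * the DSP): per element the HVX helpers compute silu(a) * b, where a is the gate and
 * b the value. With a single fused tensor the two halves sit side by side in src0 and
 * nc is ne0 / 2; 'swapped' selects which half is the gate.
 */
static inline void swiglu_f32_row_ref(const float * src0, const float * src1, float * dst,
                                      int nc, int swapped, int src1_valid) {
    const float * a = src1_valid ? src0 : src0 + (swapped ? nc : 0);
    const float * b = src1_valid ? src1 : src0 + (swapped ? 0 : nc);
    for (int k = 0; k < nc; k++) {
        const float sig = 1.0f / (1.0f + expf(-a[k]));  // sigmoid(a[k])
        dst[k] = a[k] * sig * b[k];                      // silu(a) * b
    }
}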
|
||||
|
||||
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble3;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
data_src1 = data_src0;
|
||||
}
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
const int32_t swapped = op_params[1];
|
||||
const float alpha = ((const float *) (op_params))[2];
|
||||
const float limit = ((const float *) (op_params))[3];
|
||||
|
||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (!src1_valid) {
|
||||
src0 += swapped ? nc : 0;
|
||||
src1 += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
// x (src0_spad_data) = std::min(src0_p[k], limit);
|
||||
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
|
||||
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
|
||||
hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
|
||||
// y (src1_spad_data) = y1 + 1.f
|
||||
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
|
||||
// x1 (dst_spad_data) = alpha * (x)
|
||||
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
|
||||
// x2 (dst_spad_data) = expf(-x1)
|
||||
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
|
||||
// x3 (dst_spad_data) = x2 + 1.f
|
||||
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
|
||||
// x4 (dst_spad_data) = 1 / x3
|
||||
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
|
||||
// out_glu(dst_spad_data) = x * x4
|
||||
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
|
||||
// out = out_glu * (y + 1.f);
|
||||
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
|
||||
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
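/*
 * Scalar reference for the swiglu-oai chain above (sketch only): the sequence of
 * hvx_min / hvx_clamp / hvx_exp / hvx_inverse / hvx_mul calls computes, per element,
 *   x = min(a, limit), y = clamp(b, -limit, limit),
 *   dst = x * sigmoid(alpha * x) * (y + 1).
 */
static inline void swiglu_oai_f32_row_ref(const float * a, const float * b, float * dst,
                                          int nc, float alpha, float limit) {
    for (int k = 0; k < nc; k++) {
        const float x   = fminf(a[k], limit);
        const float y   = fminf(fmaxf(b[k], -limit), limit);
        const float sig = 1.0f / (1.0f + expf(-alpha * x));
        dst[k] = x * sig * (y + 1.0f);
    }
}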
|
||||
|
||||
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
|
||||
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||
} else {
|
||||
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
|
||||
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
|
||||
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
|
||||
|
||||
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
|
||||
FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
worker_callback_t act_op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_fp32;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_fp32;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_fp32;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * octx->n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_activations_fp32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,344 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
|
||||
const uint8_t * src1,
|
||||
uint8_t * data_dst,
|
||||
const int num_elems);
|
||||
|
||||
static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
|
||||
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
|
||||
|
||||
#define htp_binary_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad_data,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
enum htp_op op) {
|
||||
htp_binary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
|
||||
|
||||
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
|
||||
|
||||
const uint32_t nr0 = ne00 / ne10;
|
||||
|
||||
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
|
||||
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
|
||||
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
const uint8_t * restrict src1_ptr = NULL;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (nr0 > 1) {
|
||||
if ((1 == is_aligned) && (nr0 == ne00)) {
|
||||
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
|
||||
} else {
|
||||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
|
||||
}
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
}
|
||||
|
||||
src0_ptr += src0_row_size;
|
||||
dst_ptr += dst_row_size;
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
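/*
 * Scalar view of the broadcast handling above (sketch only): src1 rows are cycled
 * with ir % src1_nrows, and when ne00 > ne10 the src1 row is repeated
 * nr0 = ne00 / ne10 times in the scratchpad so one HVX call covers the whole dst row.
 */
static inline void binary_row_ref(const float * src0, const float * src1, float * dst,
                                  uint32_t ne00, uint32_t ne10, float (*op)(float, float)) {
    for (uint32_t k = 0; k < ne00; k++) {
        dst[k] = op(src0[k], src1[k % ne10]);  // src1 broadcast along dim 0
    }
}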
|
||||
|
||||
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
const struct htp_tensor * src2,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad_data,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
hvx_elemwise_f32_func func_HVX) {
|
||||
htp_binary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t ne02_ne01 = ne02 * ne01;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
// src0 indices
|
||||
const uint32_t i03 = ir / ne02_ne01;
|
||||
const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
|
||||
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
|
||||
|
||||
// src1 indices
|
||||
const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
|
||||
assert(i11 >= 0 && i11 < ne11);
|
||||
|
||||
float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
|
||||
const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
|
||||
const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t nr0 = ne00 / ne10;
|
||||
if (nr0 > 1) {
|
||||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
|
||||
src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
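/*
 * Row selection in add-id above, in scalar form (sketch only, assumes ne10 == ne00):
 * for each src0 row (i01, i02) the src2 tensor holds the index i11 of the src1 row
 * (e.g. a per-expert bias) that gets added to it.
 */
static inline void add_id_row_ref(const float * src0_row, const struct htp_tensor * src1,
                                  const struct htp_tensor * src2, uint32_t i01, uint32_t i02,
                                  uint32_t ne00, float * dst_row) {
    const int32_t i11 = *(const int32_t *) ((const char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
    const float * src1_row = (const float *) ((const char *) src1->data + i11 * src1->nb[1]);
    for (uint32_t k = 0; k < ne00; k++) {
        dst_row[k] = src0_row[k] + src1_row[k];
    }
}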
|
||||
|
||||
static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
|
||||
octx->src0_nrows_per_thread, octx->op);
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD_ID:
|
||||
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
|
||||
i, octx->src0_nrows_per_thread, hvx_add_f32);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Binary Op %u", octx->op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t binary_op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "mul-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "add-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_SUB:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "sub-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD_ID:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "add-id-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src1->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
FARF(HIGH,
|
||||
"%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
|
||||
octx->ctx->vtcm_size, spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_binary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_binary_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||
return()
|
||||
endif()
|
||||
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||
|
||||
#Cross Compiling for Hexagon
|
||||
set(HEXAGON TRUE)
|
||||
set(CMAKE_SYSTEM_NAME QURT)
|
||||
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||
set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL})
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CUSTOM_RUNELF_PATH "")
|
||||
|
||||
#To fix backward compatibility with EAI addon.
|
||||
if (NOT HEXAGON_SDK_ROOT)
|
||||
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||
endif()
|
||||
|
||||
if (NOT HEXAGON_TOOLS_ROOT)
|
||||
if (DEFINED ENV{HEXAGON_TOOLS_ROOT})
|
||||
set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})
|
||||
endif()
|
||||
if(NOT HEXAGON_TOOLS_ROOT)
|
||||
set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
|
||||
#Get the Binary extension of the Hexagon Toolchain
|
||||
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||
endif()
|
||||
message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}")
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)
|
||||
|
||||
set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})
|
||||
set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib")
|
||||
set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)
|
||||
|
||||
set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||
HEXAGON_SDK_ROOT
|
||||
HEXAGON_TOOLS_ROOT
|
||||
)
|
||||
|
||||
#QURT Related includes and linker flags
|
||||
set(V_ARCH ${HEXAGON_ARCH})
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
|
||||
if( ${TREE} MATCHES PAKMAN )
|
||||
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
endif()
|
||||
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||
set(RTOS_DIR ${_QURT_INSTALL_DIR})
|
||||
set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0")
|
||||
set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0")
|
||||
|
||||
include_directories(
|
||||
${_QURT_INSTALL_DIR}/include
|
||||
${_QURT_INSTALL_DIR}/include/qurt
|
||||
${_QURT_INSTALL_DIR}/include/posix
|
||||
)
|
||||
|
||||
set(QURT_START_LINK_LIBS)
|
||||
set(QURT_START_LINK_LIBS
|
||||
"${TARGET_DIR}/init.o"
|
||||
"${RTOS_DIR}/lib/crt1.o"
|
||||
"${RTOS_DIR}/lib/debugmon.o"
|
||||
"${RTOS_DIR}/lib/libqurt.a"
|
||||
"${TARGET_DIR}/libc.a"
|
||||
"${TARGET_DIR}/libqcc.a"
|
||||
"${TARGET_DIR}/libhexagon.a"
|
||||
"${RTOS_DIR}/lib/libqurtcfs.a"
|
||||
"${RTOS_DIR}/lib/libtimer_island.a"
|
||||
"${RTOS_DIR}/lib/libtimer_main.a"
|
||||
"${RTOS_DIR}/lib/libposix.a"
|
||||
)
|
||||
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||
|
||||
set(QURT_END_LINK_LIBS
|
||||
${TARGET_DIR}/fini.o
|
||||
)
|
||||
|
||||
#Non QURT related includes and linker flags
|
||||
|
||||
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||
|
||||
if (NOT NO_WRAP_MEM_API)
|
||||
set(WRAP_MALLOC -Wl,--wrap=malloc)
|
||||
set(WRAP_CALLOC -Wl,--wrap=calloc)
|
||||
set(WRAP_FREE -Wl,--wrap=free)
|
||||
set(WRAP_REALLOC -Wl,--wrap=realloc)
|
||||
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||
endif()
|
||||
|
||||
set(PIC_SHARED_LD_FLAGS
|
||||
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||
-G0
|
||||
-fpic
|
||||
-Wl,-Bsymbolic
|
||||
-Wl,-L${TARGET_DIR_NOOS}/G0/pic
|
||||
-Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/
|
||||
-Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}
|
||||
-shared
|
||||
"-o <TARGET> <SONAME_FLAG><TARGET_SONAME>"
|
||||
"<LINK_FLAGS>"
|
||||
-Wl,--start-group
|
||||
"<OBJECTS>"
|
||||
"<LINK_LIBRARIES>"
|
||||
-Wl,--end-group
|
||||
-lc
|
||||
)
|
||||
STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
#System include paths
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||
|
||||
#LLVM toolchain setup
|
||||
#Compiler paths, options and architecture
|
||||
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(HEXAGON_LINKER ${CMAKE_C_COMPILER})
|
||||
set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" )
|
||||
|
||||
#Linker Options
|
||||
set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||
set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||