Add experimental ggml-hexagon backend for the Hexagon NPU (#16547)
* model: add support for extra bufs for all devices * hexagon: add experimental ggml-hexagon backend for the Hexagon NPU This commit introduces a new experimental backend `ggml-hexagon` with support for the Hexagon NPU. Highlights: - Supports Hexagon versions: v73, v75, v79, and v81 - Targets Android devices based on Snapdragon SoCs: Gen3, 8-Elite, and 8-Elite Gen5 - Supports Q4_0, Q8_0, MXFP4, and FP32 data types - Implements core LLM ops: MUL_MAT/MUL_MAT_ID, ADD/SUB/MUL/ADD_ID, RMS_NORM, ROPE, GLU/SWIGLU, SOFTMAX **Note:** This backend is experimental and may exhibit instability or limited performance across supported devices. It is intended for early testing and feedback from llama.cpp/ggml developer and user community. Co-Authored-By: Rajdeep Ganguly <rganguly@qti.qualcomm.com> Co-Authored-By: Todor Boinovski <todorb@qti.qualcomm.com> * hexagon: fix format checker errors * hexagon: update readme and cmake presets * ci: add android-ndk-build jobs that build plain ARM64 and Snapdragon versions * hexagon: add simple graph optimizer for stacking MUL_MAT ops with the same input * hexagon: move ADB helper scripts into scripts/snapdragon/adb * hexagon: replace all f/printfs with GGML_LOG_... * readme: add hexagon to the list supported backends * hexagon: stack malmuts with quantized inputs only * hexagon: add TODO for fixing issues in hexagon_graph_optimize * hexagon: update to hex-sdk 6.4.0 and add scripts for running on QDC * scripts: fix lint errors * scripts: update qdc pytest script to make linter happy * hexagon: add reduce sum in fp32 * hexagon: reduce number of vector stores in matmul output * hexagon: remove the need for vdelta in reduce-multiply-x8 * hexagon: consistent use of reduce_sum_fp32 for row_sums * hexagon: some more matmul optimizations and comments Optimize cases where tensor dims are not multiple of 1024 (e.g in Qwen models). We've handled those cases already but at a higher overhead. * hexagon: update cmake presets * hexagon: add OPMASK support for run-bench.sh wrapper * hexagon: update to use GGML_BACKEND_API * hexagon: remove unused logic for setting tensor flags for the views * hexagon: add asserts to set/get_tensor to make sure we handle complete tensors Same asserts as the CPU backend. * hexagon: use cpy_tensor slow path for non-host buffers * hexagon: error checks in the buffer allocator * cmake: move include(extProj) under ggml-hexagon * hexagon: don't forget to delete the backend on free * hexagon: set/get_tensor size assert apply only to quantized tensors * hexagon: reintroduce HEX_VERBOSE wrapper for GGML_LOG_DEBUG for now GGML_LOG_DEBUG is always enabled for test-backend-ops and the output gets in the way. Ideally we need a bit more finer log levels. * docs: typos in hexagon developer docs (libggm-...) * hexagon: overhaul error handling in the session/device allocation this should handle all failure paths in the session allocation. * hexagon: update cmake presets to enable fp16 vectors * hexagon: remove unused time_usec function * hexagon: don't forget to release buffer contexts * hexagon: fixed indents in hvx-utils (missed clang-format auto-format failure) * hexagon: remove custom can_repeat function and use ggml_can_repeat --------- Co-authored-by: Rajdeep Ganguly <rganguly@qti.qualcomm.com> Co-authored-by: Todor Boinovski <todorb@qti.qualcomm.com>
This commit is contained in:
parent
a2e0088d92
commit
63d2fc46e1
|
|
@ -1305,6 +1305,81 @@ jobs:
|
||||||
cd examples/llama.android
|
cd examples/llama.android
|
||||||
./gradlew build --no-daemon
|
./gradlew build --no-daemon
|
||||||
|
|
||||||
|
android-ndk-build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
env:
|
||||||
|
OPENCL_VERSION: 2025.07.22
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- build: 'arm64-cpu'
|
||||||
|
defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_CURL=OFF -D GGML_OPENMP=OFF'
|
||||||
|
- build: 'arm64-snapdragon'
|
||||||
|
defines: '--preset arm64-android-snapdragon-release'
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Clone
|
||||||
|
id: checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install OpenCL Headers and Libs
|
||||||
|
id: install_opencl
|
||||||
|
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||||
|
run: |
|
||||||
|
mkdir opencl
|
||||||
|
curl -L -o opencl/clhpp.tar.gz https://github.com/KhronosGroup/OpenCL-CLHPP/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||||
|
curl -L -o opencl/headers.tar.gz https://github.com/KhronosGroup/OpenCL-Headers/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||||
|
curl -L -o opencl/icd-loader.tar.gz https://github.com/KhronosGroup/OpenCL-ICD-Loader/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||||
|
tar -xaf opencl/headers.tar.gz -C opencl
|
||||||
|
tar -xaf opencl/clhpp.tar.gz -C opencl
|
||||||
|
tar -xaf opencl/icd-loader.tar.gz -C opencl
|
||||||
|
sudo cp -r opencl/OpenCL-Headers-${OPENCL_VERSION}/CL ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||||
|
sudo cp -r opencl/OpenCL-CLHPP-${OPENCL_VERSION}/include/CL/* ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include/CL
|
||||||
|
cd opencl/OpenCL-ICD-Loader-${OPENCL_VERSION}
|
||||||
|
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -DOPENCL_ICD_LOADER_HEADERS_DIR=${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=31 -DANDROID_STL=c++_shared
|
||||||
|
cmake --build build
|
||||||
|
sudo cp build/libOpenCL.so ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
|
||||||
|
rm -rf opencl
|
||||||
|
|
||||||
|
- name: Install Hexagon SDK
|
||||||
|
id: install_hexsdk
|
||||||
|
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||||
|
env:
|
||||||
|
HEXSDK_VER: 6.4.0.2
|
||||||
|
HEXTLS_VER: 19.0.04
|
||||||
|
run: |
|
||||||
|
curl -L -o hex-sdk.tar.gz https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v$HEXSDK_VER/hexagon-sdk-v$HEXSDK_VER-amd64-lnx.tar.xz
|
||||||
|
mkdir hex-sdk
|
||||||
|
tar -xaf hex-sdk.tar.gz -C hex-sdk
|
||||||
|
ls -l hex-sdk
|
||||||
|
sudo mv hex-sdk /opt/hexagon
|
||||||
|
echo "HEXAGON_SDK_ROOT=/opt/hexagon/$HEXSDK_VER" >> "$GITHUB_ENV"
|
||||||
|
echo "HEXAGON_TOOLS_ROOT=/opt/hexagon/$HEXSDK_VER/tools/HEXAGON_Tools/$HEXTLS_VER" >> "$GITHUB_ENV"
|
||||||
|
echo "DEFAULT_HLOS_ARCH=64" >> "$GITHUB_ENV"
|
||||||
|
echo "DEFAULT_TOOLS_VARIANT=toolv19" >> "$GITHUB_ENV"
|
||||||
|
echo "DEFAULT_NO_QURT_INC=0" >> "$GITHUB_ENV"
|
||||||
|
echo "DEFAULT_DSP_ARCH=v73" >> "$GITHUB_ENV"
|
||||||
|
|
||||||
|
- name: Update CMake presets
|
||||||
|
id: update_presets
|
||||||
|
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||||
|
run: |
|
||||||
|
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
id: ndk_build
|
||||||
|
run: |
|
||||||
|
cmake ${{ matrix.defines }} -B build
|
||||||
|
cmake --build build
|
||||||
|
cmake --install build --prefix pkg-adb/llama.cpp
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
id: cmake_test
|
||||||
|
run: |
|
||||||
|
echo "FIXME: test on devices"
|
||||||
|
|
||||||
openEuler-latest-cmake-cann:
|
openEuler-latest-cmake-cann:
|
||||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||||
defaults:
|
defaults:
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,7 @@
|
||||||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||||
/ggml/src/ggml-metal/ @ggerganov
|
/ggml/src/ggml-metal/ @ggerganov
|
||||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||||
|
/ggml/src/ggml-hexagon/ @max-krasnyansky
|
||||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||||
/ggml/src/ggml-quants.* @ggerganov
|
/ggml/src/ggml-quants.* @ggerganov
|
||||||
/ggml/src/ggml-rpc/ @rgerganov
|
/ggml/src/ggml-rpc/ @rgerganov
|
||||||
|
|
|
||||||
|
|
@ -280,6 +280,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||||
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
|
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
|
||||||
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
||||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||||
|
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
|
||||||
|
|
||||||
## Obtaining and quantizing models
|
## Obtaining and quantizing models
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
{
|
||||||
|
"version": 4,
|
||||||
|
"configurePresets": [
|
||||||
|
{
|
||||||
|
"name": "arm64-android-snapdragon",
|
||||||
|
"hidden": true,
|
||||||
|
"architecture": { "value": "arm64", "strategy": "external" },
|
||||||
|
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||||
|
"cacheVariables": {
|
||||||
|
"ANDROID_ABI": "arm64-v8a",
|
||||||
|
"ANDROID_PLATFORM": "android-31",
|
||||||
|
"CMAKE_TOOLCHAIN_FILE": "$env{ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake",
|
||||||
|
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||||
|
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||||
|
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||||
|
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||||
|
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||||
|
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||||
|
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||||
|
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||||
|
"GGML_OPENMP": "OFF",
|
||||||
|
"GGML_LLAMAFILE": "OFF",
|
||||||
|
"GGML_OPENCL": "ON",
|
||||||
|
"GGML_HEXAGON": "ON",
|
||||||
|
"LLAMA_CURL": "OFF"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"name": "arm64-windows-snapdragon",
|
||||||
|
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||||
|
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||||
|
"GGML_OPENMP": "OFF",
|
||||||
|
"GGML_LLAMAFILE": "OFF",
|
||||||
|
"GGML_OPENCL": "ON",
|
||||||
|
"GGML_HEXAGON": "ON",
|
||||||
|
"LLAMA_CURL": "OFF"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
{ "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] },
|
||||||
|
{ "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] },
|
||||||
|
|
||||||
|
{ "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] },
|
||||||
|
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,239 @@
|
||||||
|
# Snapdragon-based Android devices
|
||||||
|
|
||||||
|
## How to Build
|
||||||
|
|
||||||
|
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||||
|
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||||
|
|
||||||
|
This method works on Linux, macOS, and Windows. macOS and Windows users should install Docker Desktop.
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.3
|
||||||
|
[d]/> cd /workspace
|
||||||
|
```
|
||||||
|
|
||||||
|
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||||
|
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||||
|
|
||||||
|
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||||
|
Preset CMake variables:
|
||||||
|
ANDROID_ABI="arm64-v8a"
|
||||||
|
...
|
||||||
|
CMAKE_TOOLCHAIN_FILE="/opt/android-ndk-r28b/build/cmake/android.toolchain.cmake"
|
||||||
|
GGML_HEXAGON="ON"
|
||||||
|
GGML_OPENCL="ON"
|
||||||
|
GGML_OPENMP="OFF"
|
||||||
|
HEXAGON_SDK_ROOT="/opt/hexagon/6.4.0.2"
|
||||||
|
...
|
||||||
|
-- Including OpenCL backend
|
||||||
|
-- Including Hexagon backend
|
||||||
|
...
|
||||||
|
-- Build files have been written to: /workspace/build-snapdragon
|
||||||
|
|
||||||
|
[d]/workspace> cmake --build build-snapdragon
|
||||||
|
...
|
||||||
|
[144/356] Performing build step for 'htp-v73'
|
||||||
|
[1/16] Generating htp_iface_skel.c, htp_iface_stub.c, htp_iface.h
|
||||||
|
[2/16] Building C object CMakeFiles/ggml-htp-v73.dir/hvx-sigmoid.c.obj
|
||||||
|
[3/16] Building C object CMakeFiles/ggml-htp-v73.dir/htp-dma.c.obj
|
||||||
|
[4/16] Building C object CMakeFiles/ggml-htp-v73.dir/worker-pool.c.obj
|
||||||
|
...
|
||||||
|
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v73.so
|
||||||
|
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v75.so
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
To generate an installable "package" simply use cmake --install:
|
||||||
|
|
||||||
|
```
|
||||||
|
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||||
|
-- Install configuration: "Release"
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||||
|
...
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||||
|
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to Install
|
||||||
|
|
||||||
|
For this step, your device needs to be configured for on-device development.
|
||||||
|
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||||
|
|
||||||
|
Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
|
||||||
|
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||||
|
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||||
|
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||||
|
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||||
|
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||||
|
```
|
||||||
|
|
||||||
|
At this point, you should also install some models:
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ wget https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||||
|
...
|
||||||
|
2025-10-11 12:04:52 (10.7 MB/s) - ‘Llama-3.2-1B-Instruct-Q4_0.gguf’ saved [773025920/773025920]
|
||||||
|
|
||||||
|
~/src/llama.cpp$ adb push Llama-3.2-1B-Instruct-Q4_0.gguf /data/local/tmp/gguf
|
||||||
|
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to Run
|
||||||
|
|
||||||
|
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
|
||||||
|
|
||||||
|
llama.cpp supports three backends on Snapdragon-based devices: CPU, Adreno GPU (GPUOpenCL), and Hexagon NPU (HTP0-4).
|
||||||
|
You can select which backend to run the model on using the `D=` variable, which maps to the `--device` option.
|
||||||
|
|
||||||
|
Hexagon NPU behaves as a "GPU" device when it comes to `-ngl` and other offload-related options.
|
||||||
|
|
||||||
|
Here are some examples of running various llama.cpp tools via ADB.
|
||||||
|
|
||||||
|
Simple question for Llama-3.2-1B
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?"
|
||||||
|
...
|
||||||
|
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||||
|
ggml-hex: Hexagon Arch version v79
|
||||||
|
ggml-hex: allocating new session: HTP0
|
||||||
|
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb4000072c7955e50
|
||||||
|
...
|
||||||
|
load_tensors: offloading output layer to GPU
|
||||||
|
load_tensors: offloaded 17/17 layers to GPU
|
||||||
|
load_tensors: CPU model buffer size = 225.49 MiB
|
||||||
|
load_tensors: HTP0 model buffer size = 0.26 MiB
|
||||||
|
load_tensors: HTP0-REPACK model buffer size = 504.00 MiB
|
||||||
|
...
|
||||||
|
I hope this helps you understand the world's most popular cookies! [end of text]
|
||||||
|
...
|
||||||
|
llama_perf_sampler_print: sampling time = 30.08 ms / 487 runs ( 0.06 ms per token, 16191.77 tokens per second)
|
||||||
|
llama_perf_context_print: load time = 617.94 ms
|
||||||
|
llama_perf_context_print: prompt eval time = 80.76 ms / 11 tokens ( 7.34 ms per token, 136.21 tokens per second)
|
||||||
|
llama_perf_context_print: eval time = 9210.59 ms / 475 runs ( 19.39 ms per token, 51.57 tokens per second)
|
||||||
|
llama_perf_context_print: total time = 9454.92 ms / 486 tokens
|
||||||
|
llama_perf_context_print: graphs reused = 473
|
||||||
|
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||||
|
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - Host | 439 = 225 + 136 + 77 |
|
||||||
|
llama_memory_breakdown_print: | - HTP0-REPACK | 504 = 504 + 0 + 0 |
|
||||||
|
```
|
||||||
|
|
||||||
|
Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv
|
||||||
|
...
|
||||||
|
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||||
|
ggml-hex: Hexagon Arch version v81
|
||||||
|
ggml-hex: allocating new session: HTP0
|
||||||
|
ggml-hex: allocating new session: HTP1
|
||||||
|
...
|
||||||
|
load_tensors: offloading output layer to GPU
|
||||||
|
load_tensors: offloaded 17/17 layers to GPU
|
||||||
|
load_tensors: CPU model buffer size = 143.86 MiB
|
||||||
|
load_tensors: HTP1 model buffer size = 0.23 MiB
|
||||||
|
load_tensors: HTP1-REPACK model buffer size = 1575.00 MiB
|
||||||
|
load_tensors: HTP0 model buffer size = 0.28 MiB
|
||||||
|
load_tensors: HTP0-REPACK model buffer size = 2025.00 MiB
|
||||||
|
...
|
||||||
|
llama_context: CPU output buffer size = 0.19 MiB
|
||||||
|
llama_kv_cache: HTP1 KV buffer size = 238.00 MiB
|
||||||
|
llama_kv_cache: HTP0 KV buffer size = 306.00 MiB
|
||||||
|
llama_kv_cache: size = 544.00 MiB ( 8192 cells, 16 layers, 1/1 seqs), K (q8_0): 272.00 MiB, V (q8_0): 272.00 MiB
|
||||||
|
llama_context: HTP0 compute buffer size = 15.00 MiB
|
||||||
|
llama_context: HTP1 compute buffer size = 15.00 MiB
|
||||||
|
llama_context: CPU compute buffer size = 24.56 MiB
|
||||||
|
...
|
||||||
|
llama_perf_context_print: prompt eval time = 1730.57 ms / 212 tokens ( 8.16 ms per token, 122.50 tokens per second)
|
||||||
|
llama_perf_context_print: eval time = 5624.75 ms / 257 runs ( 21.89 ms per token, 45.69 tokens per second)
|
||||||
|
llama_perf_context_print: total time = 7377.33 ms / 469 tokens
|
||||||
|
llama_perf_context_print: graphs reused = 255
|
||||||
|
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||||
|
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - Host | 742 = 144 + 544 + 54 |
|
||||||
|
llama_memory_breakdown_print: | - HTP1-REPACK | 1575 = 1575 + 0 + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP0-REPACK | 2025 = 2025 + 0 + 0 |
|
||||||
|
```
|
||||||
|
|
||||||
|
Op test for MUL_MAT
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ HB=0 ./scripts/snapdragon/adb/run-tool.sh test-backend-ops -b HTP0 -o MUL_MAT
|
||||||
|
...
|
||||||
|
Backend 2/3: HTP0
|
||||||
|
Device description: Hexagon
|
||||||
|
Device memory: 2048 MB (2048 MB free)
|
||||||
|
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||||
|
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||||
|
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||||
|
|
||||||
|
~/src/llama.cpp-hexagon$ M=Llama-3.2-1B-Instruct-Q4_0.gguf ./scripts/snapdragon/adb/run-bench.sh -p 128 -n 64
|
||||||
|
...
|
||||||
|
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||||
|
ggml-hex: Hexagon Arch version v79
|
||||||
|
ggml-hex: allocating new session: HTP0
|
||||||
|
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb400007d4b231090
|
||||||
|
| model | size | params | backend | ngl | threads | n_batch | mmap | test | t/s |
|
||||||
|
| ---------------| ---------: | -----: | ---------- | --: | ------: | ------: | ---: | ----: | ------------: |
|
||||||
|
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | pp128 | 169.42 ± 1.75 |
|
||||||
|
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | tg64 | 51.54 ± 1.13 |
|
||||||
|
|
||||||
|
build: 6a8cf8914 (6733)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Environment variables
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_NDEV=1`
|
||||||
|
Controls the number of devices/sessions to allocate. The default is 1.
|
||||||
|
Most quantized models under 4B fit into a single session; an 8B model needs two, and a 20B model needs four.
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_NHVX=0`
|
||||||
|
Controls the number of HVX hardware threads to use. The default is all (actual number varies depending on the hardware version).
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_HOSTBUF=1`
|
||||||
|
Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
|
||||||
|
This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_VERBOSE=1`
|
||||||
|
Enables verbose logging of Ops from the backend. Example output:
|
||||||
|
|
||||||
|
```
|
||||||
|
ggml-hex: HTP0 graph-compute n_nodes 2
|
||||||
|
ggml-hex: HTP0 matmul : blk.27.ffn_up.weight x ffn_norm-27 -> ffn_up-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x1
|
||||||
|
ggml-hex: HTP0 matmul : blk.27.ffn_gate.weight x ffn_norm-27 -> ffn_gate-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x3
|
||||||
|
ggml-hex: HTP0 graph-compute n_nodes 1
|
||||||
|
ggml-hex: HTP0 matmul : blk.27.ffn_down.weight x ffn_gate_par-27 -> ffn_out-27 : 8192:3072 x 8192:1 -> 3072:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x0
|
||||||
|
ggml-hex: HTP0 get-tensor result_output : data 0x7592487000 offset 0 size 513024
|
||||||
|
```
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_PROFILE=1`
|
||||||
|
Generates a host-side profile for the ggml-hexagon Ops.
|
||||||
|
|
||||||
|
- `GGML_HEXAGON_OPMASK=0x0`
|
||||||
|
Allows enabling specific stages of the processing pipeline:
|
||||||
|
|
||||||
|
- `0x1` Enable Op Queue (i.e., queuing Ops into NPU)
|
||||||
|
- `0x2` Enable Dynamic Quantizer (if needed for the Op)
|
||||||
|
- `0x4` Enable Op Compute (MUL_MAT, etc.)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
`GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||||
|
`GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest
|
||||||
|
`GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default)
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Hexagon backend developer details
|
||||||
|
|
||||||
|
## Backend libraries
|
||||||
|
|
||||||
|
The Hexagon backend consist of two parts:
|
||||||
|
|
||||||
|
- `libggml-hexagon`
|
||||||
|
This is the regular CPU-side GGML backend library, either shared or statically linked
|
||||||
|
|
||||||
|
- `libggml-htp-vNN`
|
||||||
|
This is the NPU-side (HTP stands for Hexagon Tensor Processor) shared library that contains the Op dispatcher and kernels.
|
||||||
|
The correct library is selected automatically at runtime based on the HW version.
|
||||||
|
|
||||||
|
Here is an example of the build artifacts
|
||||||
|
|
||||||
|
```
|
||||||
|
~/src/llama.cpp$ ls -l pkg-adb/llama.cpp/lib/libggml*
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-base.so
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-hexagon.so <<< CPU library
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-htp-v73.so <<< HTP op/kernels for Hexagon v73
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||||
|
pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||||
|
```
|
||||||
|
|
||||||
|
## Memory buffers
|
||||||
|
|
||||||
|
Hexagon NPU backend takes advantage of the Snapdragon's unified memory model where all buffers are fully accessible by the CPU and GPU.
|
||||||
|
The NPU does have a dedicated tightly-coupled memory called VTCM but that memory is used only for intermediate data (e.g. dynamically
|
||||||
|
quantized tensors) or temporary data (chunks of the weight tensors fetched via DMA).
|
||||||
|
|
||||||
|
Please note that currently the Hexagon backend does not implement SET/GET_ROWS Ops because there is no advantage in offloading those
|
||||||
|
to the NPU at this point.
|
||||||
|
|
||||||
|
The backend does allocates non-host buffers for the tensors with datatypes that require repacking: Q4_0, Q8_0, MXFP4.
|
||||||
|
From the MMU perspective these buffers are still regular buffers (normal access by the CPU) they are marked as non-host simply to force
|
||||||
|
the repacking.
|
||||||
|
|
||||||
|
## Large model handling
|
||||||
|
|
||||||
|
Hexagon NPU session (aka Process Domain (PD) in the Hexagon docs) is limited to a memory mapping of around 3.5GB.
|
||||||
|
In llama.cpp/GGML the Hexagon session is mapped to a single GGML backend device (HTP0, HTP1, etc).
|
||||||
|
|
||||||
|
In order to map models larger than 3.5GB we need to allocate multiple devices and split the model.
|
||||||
|
For this we're taking advantage of the llama.cpp/GGML multi-GPU layer-splitting support.
|
||||||
|
Each Hexagon device behaves like a GPU from the offload and model splitting perspective.
|
||||||
|
|
||||||
|
Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR.
|
||||||
|
|
||||||
|
```
|
||||||
|
M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32
|
||||||
|
...
|
||||||
|
LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||||
|
ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||||
|
GGML_HEXAGON_NDEV=4 ./bin/llama-cli --no-mmap -m /data/local/tmp/llama.cpp/../gguf/gpt-oss-20b-Q4_0.gguf
|
||||||
|
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on -ngl 99 --device HTP0,HTP1,HTP2,HTP3 -no-cnv -f surfing.txt
|
||||||
|
...
|
||||||
|
llama_model_loader: - type f32: 289 tensors
|
||||||
|
llama_model_loader: - type q4_0: 96 tensors
|
||||||
|
llama_model_loader: - type q8_0: 2 tensors
|
||||||
|
llama_model_loader: - type mxfp4: 72 tensors
|
||||||
|
...
|
||||||
|
load_tensors: offloaded 25/25 layers to GPU
|
||||||
|
load_tensors: CPU model buffer size = 1182.09 MiB
|
||||||
|
load_tensors: HTP1 model buffer size = 6.64 MiB
|
||||||
|
load_tensors: HTP1-REPACK model buffer size = 2505.94 MiB
|
||||||
|
load_tensors: HTP3 model buffer size = 5.55 MiB
|
||||||
|
load_tensors: HTP3-REPACK model buffer size = 2088.28 MiB
|
||||||
|
load_tensors: HTP0 model buffer size = 7.75 MiB
|
||||||
|
load_tensors: HTP0-REPACK model buffer size = 2923.59 MiB
|
||||||
|
load_tensors: HTP2 model buffer size = 6.64 MiB
|
||||||
|
load_tensors: HTP2-REPACK model buffer size = 2505.94 MiB
|
||||||
|
...
|
||||||
|
llama_context: n_ctx_per_seq (8192) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
|
||||||
|
llama_context: CPU output buffer size = 0.77 MiB
|
||||||
|
llama_kv_cache_iswa: creating non-SWA KV cache, size = 8192 cells
|
||||||
|
llama_kv_cache: HTP1 KV buffer size = 25.50 MiB
|
||||||
|
llama_kv_cache: HTP3 KV buffer size = 25.50 MiB
|
||||||
|
llama_kv_cache: HTP0 KV buffer size = 25.50 MiB
|
||||||
|
llama_kv_cache: HTP2 KV buffer size = 25.50 MiB
|
||||||
|
llama_kv_cache: size = 102.00 MiB ( 8192 cells, 12 layers, 1/1 seqs), K (q8_0): 51.00 MiB, V (q8_0): 51.00 MiB
|
||||||
|
llama_kv_cache_iswa: creating SWA KV cache, size = 256 cells
|
||||||
|
llama_kv_cache: HTP1 KV buffer size = 0.80 MiB
|
||||||
|
llama_kv_cache: HTP3 KV buffer size = 0.53 MiB
|
||||||
|
llama_kv_cache: HTP0 KV buffer size = 1.06 MiB
|
||||||
|
llama_kv_cache: HTP2 KV buffer size = 0.80 MiB
|
||||||
|
llama_kv_cache: size = 3.19 MiB ( 256 cells, 12 layers, 1/1 seqs), K (q8_0): 1.59 MiB, V (q8_0): 1.59 MiB
|
||||||
|
llama_context: HTP0 compute buffer size = 16.06 MiB
|
||||||
|
llama_context: HTP1 compute buffer size = 16.06 MiB
|
||||||
|
llama_context: HTP2 compute buffer size = 16.06 MiB
|
||||||
|
llama_context: HTP3 compute buffer size = 16.06 MiB
|
||||||
|
llama_context: CPU compute buffer size = 98.19 MiB
|
||||||
|
...
|
||||||
|
llama_perf_context_print: prompt eval time = 3843.67 ms / 197 tokens ( 19.51 ms per token, 51.25 tokens per second)
|
||||||
|
llama_perf_context_print: eval time = 1686.13 ms / 31 runs ( 54.39 ms per token, 18.39 tokens per second)
|
||||||
|
llama_perf_context_print: total time = 6266.30 ms / 228 tokens
|
||||||
|
llama_perf_context_print: graphs reused = 30
|
||||||
|
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||||
|
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP2 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP3 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||||
|
llama_memory_breakdown_print: | - Host | 1476 = 1208 + 105 + 162 |
|
||||||
|
llama_memory_breakdown_print: | - HTP1-REPACK | 2505 = 2505 + 0 + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP3-REPACK | 2088 = 2088 + 0 + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP0-REPACK | 2923 = 2923 + 0 + 0 |
|
||||||
|
llama_memory_breakdown_print: | - HTP2-REPACK | 2505 = 2505 + 0 + 0 |
|
||||||
|
```
|
||||||
|
|
@ -251,6 +251,8 @@ option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adr
|
||||||
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||||
"gmml: OpenCL API version to target")
|
"gmml: OpenCL API version to target")
|
||||||
|
|
||||||
|
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||||
|
|
||||||
# toolchain for vulkan-shaders-gen
|
# toolchain for vulkan-shaders-gen
|
||||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ggml.h"
|
||||||
|
#include "ggml-backend.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// backend API
|
||||||
|
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
|
||||||
|
|
||||||
|
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
|
||||||
|
|
||||||
|
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
@ -402,6 +402,7 @@ ggml_add_backend(Vulkan)
|
||||||
ggml_add_backend(WebGPU)
|
ggml_add_backend(WebGPU)
|
||||||
ggml_add_backend(zDNN)
|
ggml_add_backend(zDNN)
|
||||||
ggml_add_backend(OpenCL)
|
ggml_add_backend(OpenCL)
|
||||||
|
ggml_add_backend(Hexagon)
|
||||||
|
|
||||||
foreach (target ggml-base ggml)
|
foreach (target ggml-base ggml)
|
||||||
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
||||||
|
|
|
||||||
|
|
@ -57,6 +57,10 @@
|
||||||
#include "ggml-opencl.h"
|
#include "ggml-opencl.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#ifdef GGML_USE_HEXAGON
|
||||||
|
#include "ggml-hexagon.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef GGML_USE_BLAS
|
#ifdef GGML_USE_BLAS
|
||||||
#include "ggml-blas.h"
|
#include "ggml-blas.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -199,6 +203,9 @@ struct ggml_backend_registry {
|
||||||
#ifdef GGML_USE_OPENCL
|
#ifdef GGML_USE_OPENCL
|
||||||
register_backend(ggml_backend_opencl_reg());
|
register_backend(ggml_backend_opencl_reg());
|
||||||
#endif
|
#endif
|
||||||
|
#ifdef GGML_USE_HEXAGON
|
||||||
|
register_backend(ggml_backend_hexagon_reg());
|
||||||
|
#endif
|
||||||
#ifdef GGML_USE_CANN
|
#ifdef GGML_USE_CANN
|
||||||
register_backend(ggml_backend_cann_reg());
|
register_backend(ggml_backend_cann_reg());
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -598,6 +605,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||||
ggml_backend_load_best("sycl", silent, dir_path);
|
ggml_backend_load_best("sycl", silent, dir_path);
|
||||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||||
ggml_backend_load_best("opencl", silent, dir_path);
|
ggml_backend_load_best("opencl", silent, dir_path);
|
||||||
|
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||||
ggml_backend_load_best("musa", silent, dir_path);
|
ggml_backend_load_best("musa", silent, dir_path);
|
||||||
ggml_backend_load_best("cpu", silent, dir_path);
|
ggml_backend_load_best("cpu", silent, dir_path);
|
||||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||||
|
include(ExternalProject)
|
||||||
|
|
||||||
|
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||||
|
|
||||||
|
add_library(htp_iface OBJECT
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||||
|
|
||||||
|
set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||||
|
target_include_directories(htp_iface PUBLIC
|
||||||
|
${HEXAGON_SDK_ROOT}/incs
|
||||||
|
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||||
|
${HEXAGON_SDK_ROOT}/utils/examples
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/htp
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
||||||
|
build_idl(htp/htp_iface.idl htp_iface)
|
||||||
|
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES Android)
|
||||||
|
target_link_options(htp_iface PUBLIC -llog -ldl)
|
||||||
|
elseif (CMAKE_SYSTEM_NAME MATCHES Windows)
|
||||||
|
target_precompile_headers(htp_iface PUBLIC <sal.h>)
|
||||||
|
else()
|
||||||
|
target_link_options(htp_iface PUBLIC -ldl)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
link_custom_library(htp_iface cdsprpc)
|
||||||
|
link_custom_library(htp_iface rpcmem)
|
||||||
|
|
||||||
|
set(TARGET_NAME ggml-hexagon)
|
||||||
|
ggml_add_backend_library(${TARGET_NAME}
|
||||||
|
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
|
||||||
|
|
||||||
|
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
|
||||||
|
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
||||||
|
# Build HTP bits
|
||||||
|
set(HTP_CMAKE_ARGS
|
||||||
|
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||||
|
-DCMAKE_BUILD_TYPE=Release
|
||||||
|
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||||
|
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
|
||||||
|
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||||
|
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
|
||||||
|
|
||||||
|
ExternalProject_Add(htp-v73
|
||||||
|
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||||
|
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||||
|
|
||||||
|
ExternalProject_Add(htp-v75
|
||||||
|
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||||
|
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
|
||||||
|
|
||||||
|
ExternalProject_Add(htp-v79
|
||||||
|
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||||
|
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
|
||||||
|
|
||||||
|
ExternalProject_Add(htp-v81
|
||||||
|
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||||
|
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
|
||||||
|
|
||||||
|
# Install Hexagon skels required at runtime
|
||||||
|
install(FILES
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
|
||||||
|
TYPE LIB)
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,448 @@
|
||||||
|
|
||||||
|
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||||
|
|
||||||
|
#define GGML_COMMON_IMPL_C
|
||||||
|
#include "ggml-backend-impl.h"
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "ggml-hexagon.h"
|
||||||
|
#include "ggml-impl.h"
|
||||||
|
|
||||||
|
#include "htp-utils.h"
|
||||||
|
|
||||||
|
#include <domain.h>
|
||||||
|
#include <remote.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
domain * get_domain(int domain_id) {
|
||||||
|
int i = 0;
|
||||||
|
int size = sizeof(supported_domains) / sizeof(domain);
|
||||||
|
|
||||||
|
for (i = 0; i < size; i++) {
|
||||||
|
if (supported_domains[i].id == domain_id) {
|
||||||
|
return &supported_domains[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_valid_domain_id(int domain_id, int compute_only) {
|
||||||
|
int i = 0;
|
||||||
|
int size = sizeof(supported_domains) / sizeof(domain);
|
||||||
|
|
||||||
|
if (compute_only) {
|
||||||
|
return is_CDSP(domain_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0; i < size; i++) {
|
||||||
|
if (supported_domains[i].id == domain_id) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
int ss_info = 0;
|
||||||
|
if (domain_type != NULL) {
|
||||||
|
if (strcmp(domain_type, "LPASS") == 0) {
|
||||||
|
ss_info = FASTRPC_LPASS;
|
||||||
|
} else if (strcmp(domain_type, "HPASS") == 0) {
|
||||||
|
ss_info = FASTRPC_HPASS;
|
||||||
|
} else {
|
||||||
|
ss_info = FASTRPC_NSP;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
system_req_payload req = { 0 };
|
||||||
|
req.id = FASTRPC_GET_DOMAINS;
|
||||||
|
req.sys.domains = NULL;
|
||||||
|
fastrpc_domain * domain = NULL;
|
||||||
|
if (ss_info != 0) {
|
||||||
|
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
|
||||||
|
} else {
|
||||||
|
req.sys.flags = 0;
|
||||||
|
}
|
||||||
|
#ifdef _WIN32
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
goto bail;
|
||||||
|
#endif
|
||||||
|
if (remote_system_request) {
|
||||||
|
nErr = remote_system_request(&req);
|
||||||
|
if (nErr != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
// Allocate memory for domain-info array
|
||||||
|
req.sys.max_domains = req.sys.num_domains;
|
||||||
|
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
|
||||||
|
nErr = AEE_ENOMEMORY;
|
||||||
|
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
|
||||||
|
nErr = remote_system_request(&req);
|
||||||
|
if (nErr != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < req.sys.num_domains; i++) {
|
||||||
|
// Verify that only requested type domains were returned
|
||||||
|
domain = &req.sys.domains[i];
|
||||||
|
if (domain->type != ss_info && domain_type != NULL) {
|
||||||
|
nErr = -1;
|
||||||
|
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*domains_info = req.sys.domains;
|
||||||
|
*num_domains = req.sys.num_domains;
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
bail:
|
||||||
|
if (nErr && !req.sys.domains) {
|
||||||
|
free(req.sys.domains);
|
||||||
|
}
|
||||||
|
return nErr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
|
||||||
|
int err = 0;
|
||||||
|
remote_rpc_effective_domain_id_t sess = { 0 };
|
||||||
|
|
||||||
|
sess.domain_name = domain_name;
|
||||||
|
sess.domain_name_len = strlen(domain_name);
|
||||||
|
sess.session_id = session_id;
|
||||||
|
|
||||||
|
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
|
||||||
|
if (err) {
|
||||||
|
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
|
||||||
|
session_id);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
*effec_domain_id = sess.effective_domain_id;
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_dsp_support(int * domain) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
|
||||||
|
|
||||||
|
if (remote_handle_control) {
|
||||||
|
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dsp_capability_domain.capability == 0) {
|
||||||
|
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
|
||||||
|
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
|
||||||
|
dsp_capability_domain.capability = 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if (dsp_capability_domain.capability) {
|
||||||
|
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nErr != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return nErr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
*capability = 0;
|
||||||
|
|
||||||
|
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EBADPARM;
|
||||||
|
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
if (remote_handle_control) {
|
||||||
|
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
|
||||||
|
/*
|
||||||
|
* Query the DSP for VTCM information
|
||||||
|
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
|
||||||
|
*/
|
||||||
|
struct remote_dsp_capability dsp_capability_vtcm_dsp;
|
||||||
|
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
|
||||||
|
dsp_capability_vtcm_dsp.attribute_ID = attr;
|
||||||
|
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||||
|
nErr = AEE_SUCCESS;
|
||||||
|
goto bail;
|
||||||
|
} else if (nErr == AEE_SUCCESS) {
|
||||||
|
*capability = dsp_capability_vtcm_dsp.capability;
|
||||||
|
} else {
|
||||||
|
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return nErr;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_unsignedpd_supported(int domain_id) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
if (remote_handle_control) {
|
||||||
|
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (nErr) {
|
||||||
|
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (dsp_capability_domain.capability == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool get_unsignedpd_support(void) {
|
||||||
|
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_async_fastrpc_supported(int domain) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
if (remote_handle_control) {
|
||||||
|
if (domain == CDSP_DOMAIN_ID) {
|
||||||
|
/*
|
||||||
|
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
|
||||||
|
* Async fastrpc is supported only on CDSP
|
||||||
|
*/
|
||||||
|
struct remote_dsp_capability dsp_capability_async_support;
|
||||||
|
dsp_capability_async_support.domain = (uint32_t) domain;
|
||||||
|
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
|
||||||
|
dsp_capability_async_support.capability = (uint32_t) 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||||
|
nErr = AEE_SUCCESS;
|
||||||
|
goto bail;
|
||||||
|
} else if (dsp_capability_async_support.capability == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nErr != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_status_notification_supported(int domain) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
|
||||||
|
if (remote_handle_control) {
|
||||||
|
/*
|
||||||
|
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||||
|
* DSP User PD status notification Support
|
||||||
|
*/
|
||||||
|
struct remote_dsp_capability dsp_capability_status_notification_support;
|
||||||
|
dsp_capability_status_notification_support.domain = (uint32_t) domain;
|
||||||
|
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
|
||||||
|
dsp_capability_status_notification_support.capability = (uint32_t) 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||||
|
nErr = AEE_SUCCESS;
|
||||||
|
goto bail;
|
||||||
|
} else if (dsp_capability_status_notification_support.capability == 1) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (nErr != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
*capability = 0;
|
||||||
|
|
||||||
|
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
|
||||||
|
nErr = AEE_EBADPARM;
|
||||||
|
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
if (remote_handle_control) {
|
||||||
|
if (domain == CDSP_DOMAIN_ID) {
|
||||||
|
/*
|
||||||
|
* Query the DSP for HMX SUPPORT information
|
||||||
|
* HMX is supported on CDSP only
|
||||||
|
*/
|
||||||
|
struct remote_dsp_capability dsp_capability_hmx_dsp;
|
||||||
|
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
|
||||||
|
dsp_capability_hmx_dsp.attribute_ID = attr;
|
||||||
|
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||||
|
nErr = AEE_SUCCESS;
|
||||||
|
goto bail;
|
||||||
|
} else if (nErr == AEE_SUCCESS) {
|
||||||
|
*capability = dsp_capability_hmx_dsp.capability;
|
||||||
|
} else {
|
||||||
|
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return nErr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_hex_arch_ver(int domain, int * arch) {
|
||||||
|
if (!remote_handle_control) {
|
||||||
|
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||||
|
return AEE_EUNSUPPORTEDAPI;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct remote_dsp_capability arch_ver;
|
||||||
|
arch_ver.domain = (uint32_t) domain;
|
||||||
|
arch_ver.attribute_ID = ARCH_VER;
|
||||||
|
arch_ver.capability = (uint32_t) 0;
|
||||||
|
|
||||||
|
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||||
|
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||||
|
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||||
|
return AEE_EUNSUPPORTEDAPI;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != AEE_SUCCESS) {
|
||||||
|
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (arch_ver.capability & 0xff) {
|
||||||
|
case 0x73:
|
||||||
|
*arch = 73;
|
||||||
|
return 0;
|
||||||
|
case 0x75:
|
||||||
|
*arch = 75;
|
||||||
|
return 0;
|
||||||
|
case 0x79:
|
||||||
|
*arch = 79;
|
||||||
|
return 0;
|
||||||
|
case 0x81:
|
||||||
|
*arch = 81;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||||
|
int nErr = AEE_SUCCESS;
|
||||||
|
*capability = 0;
|
||||||
|
|
||||||
|
if (remote_handle_control) {
|
||||||
|
if (domain == CDSP_DOMAIN_ID) {
|
||||||
|
/*
|
||||||
|
* Query the DSP for HVX SUPPORT information
|
||||||
|
* HVX is supported on CDSP only
|
||||||
|
*/
|
||||||
|
struct remote_dsp_capability dsp_capability_hvx_dsp;
|
||||||
|
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
|
||||||
|
dsp_capability_hvx_dsp.attribute_ID = attr;
|
||||||
|
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
|
||||||
|
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
|
||||||
|
sizeof(struct remote_dsp_capability));
|
||||||
|
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||||
|
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||||
|
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||||
|
nErr = AEE_SUCCESS;
|
||||||
|
goto bail;
|
||||||
|
} else if (nErr == AEE_SUCCESS) {
|
||||||
|
*capability = dsp_capability_hvx_dsp.capability;
|
||||||
|
} else {
|
||||||
|
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTED;
|
||||||
|
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
|
||||||
|
goto bail;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nErr = AEE_EUNSUPPORTEDAPI;
|
||||||
|
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
bail:
|
||||||
|
return nErr;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,219 @@
|
||||||
|
#ifndef HTP_UTILS_H
|
||||||
|
#define HTP_UTILS_H
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <AEEStdErr.h>
|
||||||
|
#include <inttypes.h>
|
||||||
|
#include <remote.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||||
|
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||||
|
#ifndef DSP_OFFSET
|
||||||
|
# define DSP_OFFSET 0x80000400
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Errno for connection reset by peer. */
|
||||||
|
#ifndef ECONNRESET
|
||||||
|
# ifdef __hexagon__
|
||||||
|
# define ECONNRESET 104
|
||||||
|
# endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Abstraction of different OS specific sleep APIs.
|
||||||
|
SLEEP accepts input in seconds. */
|
||||||
|
#ifndef SLEEP
|
||||||
|
# ifdef __hexagon__
|
||||||
|
# define SLEEP(x) \
|
||||||
|
{ /* Do nothing for simulator. */ \
|
||||||
|
}
|
||||||
|
# else
|
||||||
|
# ifdef _WINDOWS
|
||||||
|
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||||
|
# else
|
||||||
|
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||||
|
# endif
|
||||||
|
# endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Include windows specific header files. */
|
||||||
|
#ifdef _WINDOWS
|
||||||
|
# include <sysinfoapi.h>
|
||||||
|
# include <windows.h>
|
||||||
|
# define _CRT_SECURE_NO_WARNINGS 1
|
||||||
|
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||||
|
/* Including this file for custom implementation of getopt function. */
|
||||||
|
# include "getopt_custom.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Includes and defines for all HLOS except windows */
|
||||||
|
#if !defined(__hexagon__) && !defined(_WINDOWS)
|
||||||
|
# include "unistd.h"
|
||||||
|
|
||||||
|
# include <sys/time.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||||
|
#if !defined(_WINDOWS)
|
||||||
|
/* Weak reference to remote symbol for compilation. */
|
||||||
|
# pragma weak remote_session_control
|
||||||
|
# pragma weak remote_handle_control
|
||||||
|
# pragma weak remote_handle64_control
|
||||||
|
# pragma weak fastrpc_mmap
|
||||||
|
# pragma weak fastrpc_munmap
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if !defined(_WINDOWS)
|
||||||
|
# pragma weak remote_system_request
|
||||||
|
#endif
|
||||||
|
/**
|
||||||
|
* Wrapper for FastRPC Capability API: query DSP support.
|
||||||
|
*
|
||||||
|
* @param[out] domain pointer to supported domain.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*/
|
||||||
|
int get_dsp_support(int * domain);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper for FastRPC Capability API: query VTCM information.
|
||||||
|
*
|
||||||
|
* @param[in] domain value of domain in the queried.
|
||||||
|
* @param[out] capability capability value of the attribute queried.
|
||||||
|
* @param[in] attr value of the attribute to the queried.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*/
|
||||||
|
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
|
||||||
|
*
|
||||||
|
* @return true if unsigned pd is supported.
|
||||||
|
* false if unsigned pd is not supported, capability query failed.
|
||||||
|
*/
|
||||||
|
|
||||||
|
bool get_unsignedpd_support(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wrapper for FastRPC Capability API: query unsigned pd support.
|
||||||
|
*
|
||||||
|
* @param[in] domain value of domain in the queried.
|
||||||
|
* @return true if unsigned pd is supported.
|
||||||
|
* false if unsigned pd is not supported, capability query failed.
|
||||||
|
*/
|
||||||
|
|
||||||
|
bool is_unsignedpd_supported(int domain_id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* is_valid_domain_id API: query a domain id is valid.
|
||||||
|
*
|
||||||
|
* @param[in] domain value of domain in the queried.
|
||||||
|
* @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
|
||||||
|
* @return true if value of domain is valid.
|
||||||
|
* false if value of domain is not valid.
|
||||||
|
*/
|
||||||
|
|
||||||
|
bool is_valid_domain_id(int domain_id, int compute_only);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_domain API: get domain struct from domain value.
|
||||||
|
*
|
||||||
|
* @param[in] domain value of a domain
|
||||||
|
* @return Returns domain struct of the domain if it is supported or else
|
||||||
|
* returns NULL.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
domain * get_domain(int domain_id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_domains_info API: get information for all the domains available on the device
|
||||||
|
*
|
||||||
|
* @param[in] domain_type pointer to domain type
|
||||||
|
* @param[in] num_domains pointer to number of domains
|
||||||
|
* @param[in] domains_info pointer to save discovered domains information.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*
|
||||||
|
* It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_effective_domain_id API: get effective domain id for given session id
|
||||||
|
*
|
||||||
|
* @param[in] domain_name pointer to domain name
|
||||||
|
* @param[in] session_id
|
||||||
|
* @param[in] effec_domain_id pointer to save obtained effective domain id.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
|
||||||
|
*
|
||||||
|
* @param[in] domain_id value of a domain
|
||||||
|
* @return Returns true or false stating support of Async FastRPC
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
bool is_async_fastrpc_supported(int domain_id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||||
|
*
|
||||||
|
* @param[in] domain_id value of a domain
|
||||||
|
* @return Returns true or false stating status notification support information
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
bool is_status_notification_supported(int domain_id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
|
||||||
|
*
|
||||||
|
* @param[in] domain_id value of a domain
|
||||||
|
* @param[out] capability capability value of the attribute queried.
|
||||||
|
* @param[in] attr value of the attribute to the queried.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||||
|
*
|
||||||
|
* @param[in] domain_id value of a domain
|
||||||
|
* @param[out] Arch version (73, 75, ...)
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
int get_hex_arch_ver(int domain, int * arch);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
|
||||||
|
*
|
||||||
|
* @param[in] domain_id value of a domain
|
||||||
|
* @param[out] capability capability value of the attribute queried.
|
||||||
|
* @param[in] attr value of the attribute to the queried.
|
||||||
|
* @return 0 if query is successful.
|
||||||
|
* non-zero if error, return value points to the error.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif //DSP_CAPABILITIES_UTILS_H
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
cmake_minimum_required(VERSION 3.22.2)
|
||||||
|
project(ggml-htp C CXX ASM)
|
||||||
|
|
||||||
|
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||||
|
|
||||||
|
include_directories(
|
||||||
|
${HEXAGON_SDK_ROOT}/incs
|
||||||
|
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
|
${CMAKE_CURRENT_BINARY_DIR})
|
||||||
|
|
||||||
|
set(HTP_LIB ggml-htp-${DSP_VERSION})
|
||||||
|
|
||||||
|
add_library(${HTP_LIB} SHARED
|
||||||
|
main.c
|
||||||
|
htp_iface_skel.c
|
||||||
|
worker-pool.c
|
||||||
|
htp-dma.c
|
||||||
|
hvx-sigmoid.c
|
||||||
|
hvx-inverse.c
|
||||||
|
hvx-exp.c
|
||||||
|
hvx-utils.c
|
||||||
|
matmul-ops.c
|
||||||
|
binary-ops.c
|
||||||
|
unary-ops.c
|
||||||
|
softmax-ops.c
|
||||||
|
act-ops.c
|
||||||
|
rope-ops.c
|
||||||
|
)
|
||||||
|
|
||||||
|
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||||
|
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
|
||||||
|
|
||||||
|
build_idl(htp_iface.idl ${HTP_LIB})
|
||||||
|
|
||||||
|
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
|
install(TARGETS ${HTP_LIB})
|
||||||
|
|
@ -0,0 +1,448 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#define htp_act_preamble3 \
|
||||||
|
const uint32_t ne00 = src0->ne[0]; \
|
||||||
|
const uint32_t ne01 = src0->ne[1]; \
|
||||||
|
const uint32_t ne02 = src0->ne[2]; \
|
||||||
|
const uint32_t ne03 = src0->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne10 = src1->ne[0]; \
|
||||||
|
const uint32_t ne11 = src1->ne[1]; \
|
||||||
|
const uint32_t ne12 = src1->ne[2]; \
|
||||||
|
const uint32_t ne13 = src1->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src0->nb[0]; \
|
||||||
|
const uint32_t nb01 = src0->nb[1]; \
|
||||||
|
const uint32_t nb02 = src0->nb[2]; \
|
||||||
|
const uint32_t nb03 = src0->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb10 = src1->nb[0]; \
|
||||||
|
const uint32_t nb11 = src1->nb[1]; \
|
||||||
|
const uint32_t nb12 = src1->nb[2]; \
|
||||||
|
const uint32_t nb13 = src1->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
#define htp_act_preamble2 \
|
||||||
|
const uint32_t ne00 = src0->ne[0]; \
|
||||||
|
const uint32_t ne01 = src0->ne[1]; \
|
||||||
|
const uint32_t ne02 = src0->ne[2]; \
|
||||||
|
const uint32_t ne03 = src0->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src0->nb[0]; \
|
||||||
|
const uint32_t nb01 = src0->nb[1]; \
|
||||||
|
const uint32_t nb02 = src0->nb[2]; \
|
||||||
|
const uint32_t nb03 = src0->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
|
const struct htp_tensor * src1,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
const int32_t * op_params,
|
||||||
|
struct htp_spad * src0_spad,
|
||||||
|
struct htp_spad * src1_spad,
|
||||||
|
struct htp_spad * dst_spad,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread) {
|
||||||
|
htp_act_preamble3;
|
||||||
|
|
||||||
|
size_t src0_row_size = nb01;
|
||||||
|
size_t src1_row_size = nb11;
|
||||||
|
size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||||
|
is_aligned = 0;
|
||||||
|
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||||
|
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||||
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
|
bool src1_valid = src1->ne[0];
|
||||||
|
if (!src1_valid) {
|
||||||
|
data_src1 = data_src0;
|
||||||
|
src1_row_size = src0_row_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||||
|
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||||
|
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||||
|
|
||||||
|
const int32_t swapped = op_params[1];
|
||||||
|
|
||||||
|
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||||
|
|
||||||
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
|
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||||
|
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||||
|
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||||
|
|
||||||
|
if (ir + 1 < src0_end_row) {
|
||||||
|
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!src1_valid) {
|
||||||
|
src0 += swapped ? nc : 0;
|
||||||
|
src1 += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (1 == opt_path) {
|
||||||
|
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
|
||||||
|
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
|
||||||
|
(uint8_t *) dst, nc);
|
||||||
|
} else {
|
||||||
|
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
|
||||||
|
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
|
||||||
|
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
|
||||||
|
|
||||||
|
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
|
||||||
|
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||||
|
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||||
|
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
|
const struct htp_tensor * src1,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
const int32_t * op_params,
|
||||||
|
struct htp_spad * src0_spad,
|
||||||
|
struct htp_spad * src1_spad,
|
||||||
|
struct htp_spad * dst_spad,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread) {
|
||||||
|
htp_act_preamble3;
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
const size_t src0_row_size = nb01;
|
||||||
|
const size_t src1_row_size = nb11;
|
||||||
|
const size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||||
|
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||||
|
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||||
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
|
bool src1_valid = src1->ne[0];
|
||||||
|
if (!src1_valid) {
|
||||||
|
data_src1 = data_src0;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||||
|
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||||
|
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||||
|
|
||||||
|
const int32_t swapped = op_params[1];
|
||||||
|
const float alpha = ((const float *) (op_params))[2];
|
||||||
|
const float limit = ((const float *) (op_params))[3];
|
||||||
|
|
||||||
|
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||||
|
|
||||||
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
|
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||||
|
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||||
|
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||||
|
|
||||||
|
if (ir + 1 < src0_end_row) {
|
||||||
|
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0 += swapped ? nc : 0;
|
||||||
|
src1 += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x (src0_spad_data) = std::min(src0_p[k], limit);
|
||||||
|
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
|
||||||
|
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
|
||||||
|
hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
|
||||||
|
// y (src1_spad_data) = y1 + 1.f
|
||||||
|
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
|
||||||
|
// x1 (dst_spad_data) = alpha * (x)
|
||||||
|
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
|
||||||
|
// x2 (dst_spad_data) = expf(-x1)
|
||||||
|
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
|
||||||
|
// x3 (dst_spad_data) = x2 + 1.f
|
||||||
|
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
|
||||||
|
// x4 (dst_spad_data) = 1 / x3
|
||||||
|
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
|
||||||
|
// out_glu(dst_spad_data) = x * x4
|
||||||
|
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
|
||||||
|
// out = out_glu * (y + 1.f);
|
||||||
|
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
|
||||||
|
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
const int32_t * op_params,
|
||||||
|
struct htp_spad * src0_spad,
|
||||||
|
struct htp_spad * dst_spad,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread) {
|
||||||
|
htp_act_preamble2;
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
const size_t src0_row_size = nb01;
|
||||||
|
const size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||||
|
is_aligned = 0;
|
||||||
|
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||||
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
|
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||||
|
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||||
|
|
||||||
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
|
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||||
|
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||||
|
|
||||||
|
if (ir + 1 < src0_end_row) {
|
||||||
|
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (1 == opt_path) {
|
||||||
|
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
|
||||||
|
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||||
|
} else {
|
||||||
|
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
|
||||||
|
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
|
||||||
|
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
|
||||||
|
|
||||||
|
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
|
||||||
|
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||||
|
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||||
|
octx->src0_nrows_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||||
|
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||||
|
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||||
|
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||||
|
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
|
||||||
|
FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
worker_callback_t act_op_func;
|
||||||
|
const char * op_type = NULL;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_UNARY_SILU:
|
||||||
|
act_op_func = unary_silu_fp32;
|
||||||
|
op_type = "silu-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_GLU_SWIGLU:
|
||||||
|
act_op_func = glu_swiglu_fp32;
|
||||||
|
op_type = "swiglu-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_GLU_SWIGLU_OAI:
|
||||||
|
act_op_func = glu_swiglu_oai_fp32;
|
||||||
|
op_type = "swiglu-oai-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_threads = octx->n_threads;
|
||||||
|
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
|
const size_t src0_row_size = src0->nb[1];
|
||||||
|
const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1];
|
||||||
|
const size_t dst_row_size = dst->nb[1];
|
||||||
|
|
||||||
|
// VTCM scratchpads for all tensors
|
||||||
|
// N rows per thread, padded to HVX vector size
|
||||||
|
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * octx->n_threads;
|
||||||
|
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads;
|
||||||
|
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads;
|
||||||
|
|
||||||
|
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||||
|
|
||||||
|
if (src1->ne[0]) {
|
||||||
|
FARF(HIGH,
|
||||||
|
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||||
|
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||||
|
octx->dst_spad.size);
|
||||||
|
} else {
|
||||||
|
FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||||
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||||
|
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the reserved vtcm size is sufficient
|
||||||
|
if (octx->ctx->vtcm_size < spad_size) {
|
||||||
|
FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||||
|
spad_size);
|
||||||
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||||
|
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||||
|
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||||
|
|
||||||
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||||
|
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||||
|
|
||||||
|
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||||
|
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int op_activations(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
switch (octx->src0.type) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
err = execute_op_activations_fp32(octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = HTP_STATUS_NO_SUPPORT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,344 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
|
||||||
|
const uint8_t * src1,
|
||||||
|
uint8_t * data_dst,
|
||||||
|
const int num_elems);
|
||||||
|
|
||||||
|
static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
|
||||||
|
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
|
||||||
|
|
||||||
|
#define htp_binary_preamble \
|
||||||
|
const uint32_t ne00 = src0->ne[0]; \
|
||||||
|
const uint32_t ne01 = src0->ne[1]; \
|
||||||
|
const uint32_t ne02 = src0->ne[2]; \
|
||||||
|
const uint32_t ne03 = src0->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne10 = src1->ne[0]; \
|
||||||
|
const uint32_t ne11 = src1->ne[1]; \
|
||||||
|
const uint32_t ne12 = src1->ne[2]; \
|
||||||
|
const uint32_t ne13 = src1->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src0->nb[0]; \
|
||||||
|
const uint32_t nb01 = src0->nb[1]; \
|
||||||
|
const uint32_t nb02 = src0->nb[2]; \
|
||||||
|
const uint32_t nb03 = src0->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb10 = src1->nb[0]; \
|
||||||
|
const uint32_t nb11 = src1->nb[1]; \
|
||||||
|
const uint32_t nb12 = src1->nb[2]; \
|
||||||
|
const uint32_t nb13 = src1->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
|
||||||
|
const struct htp_tensor * src1,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
uint8_t * spad_data,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread,
|
||||||
|
enum htp_op op) {
|
||||||
|
htp_binary_preamble;
|
||||||
|
|
||||||
|
const size_t src0_row_size = nb01;
|
||||||
|
const size_t src1_row_size = nb11;
|
||||||
|
const size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||||
|
FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||||
|
is_aligned = 0;
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
|
||||||
|
|
||||||
|
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
|
||||||
|
|
||||||
|
const uint32_t nr0 = ne00 / ne10;
|
||||||
|
|
||||||
|
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
|
||||||
|
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||||
|
const uint8_t * restrict src1_ptr = NULL;
|
||||||
|
|
||||||
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
|
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
|
||||||
|
|
||||||
|
if (ir + 1 < src0_end_row) {
|
||||||
|
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||||
|
if (src1_row_size == src0_row_size) {
|
||||||
|
htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nr0 > 1) {
|
||||||
|
if ((1 == is_aligned) && (nr0 == ne00)) {
|
||||||
|
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
|
||||||
|
} else {
|
||||||
|
for (uint32_t r = 0; r < nr0; r++) {
|
||||||
|
memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
|
||||||
|
} else {
|
||||||
|
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||||
|
}
|
||||||
|
|
||||||
|
src0_ptr += src0_row_size;
|
||||||
|
dst_ptr += dst_row_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||||
|
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||||
|
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
|
||||||
|
const struct htp_tensor * src1,
|
||||||
|
const struct htp_tensor * src2,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
uint8_t * spad_data,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread,
|
||||||
|
hvx_elemwise_f32_func func_HVX) {
|
||||||
|
htp_binary_preamble;
|
||||||
|
|
||||||
|
const size_t src0_row_size = nb01;
|
||||||
|
const size_t src1_row_size = nb11;
|
||||||
|
const size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t ne02_ne01 = ne02 * ne01;
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||||
|
FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||||
|
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||||
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
|
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||||
|
// src0 indices
|
||||||
|
const uint32_t i03 = ir / ne02_ne01;
|
||||||
|
const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
|
||||||
|
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
|
||||||
|
|
||||||
|
// src1 indices
|
||||||
|
const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
|
||||||
|
assert(i11 >= 0 && i11 < ne11);
|
||||||
|
|
||||||
|
float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
|
||||||
|
const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
|
||||||
|
const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
|
||||||
|
|
||||||
|
if (ir + 1 < src0_end_row) {
|
||||||
|
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||||
|
if (src1_row_size == src0_row_size) {
|
||||||
|
htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t nr0 = ne00 / ne10;
|
||||||
|
if (nr0 > 1) {
|
||||||
|
for (uint32_t r = 0; r < nr0; r++) {
|
||||||
|
memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
|
||||||
|
}
|
||||||
|
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
|
||||||
|
} else {
|
||||||
|
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||||
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
|
||||||
|
src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
|
||||||
|
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_MUL:
|
||||||
|
case HTP_OP_ADD:
|
||||||
|
case HTP_OP_SUB:
|
||||||
|
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
|
||||||
|
octx->src0_nrows_per_thread, octx->op);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_ADD_ID:
|
||||||
|
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
|
||||||
|
i, octx->src0_nrows_per_thread, hvx_add_f32);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unknown Binary Op %u", octx->op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
worker_callback_t binary_op_func;
|
||||||
|
const char * op_type = NULL;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_MUL:
|
||||||
|
binary_op_func = binary_job_dispatcher_f32;
|
||||||
|
op_type = "mul-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_ADD:
|
||||||
|
binary_op_func = binary_job_dispatcher_f32;
|
||||||
|
op_type = "add-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_SUB:
|
||||||
|
binary_op_func = binary_job_dispatcher_f32;
|
||||||
|
op_type = "sub-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_ADD_ID:
|
||||||
|
binary_op_func = binary_job_dispatcher_f32;
|
||||||
|
op_type = "add-id-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_threads = octx->n_threads;
|
||||||
|
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
|
const size_t src0_row_size = src0->nb[1];
|
||||||
|
const size_t src1_row_size = src1->nb[1];
|
||||||
|
const size_t dst_row_size = dst->nb[1];
|
||||||
|
|
||||||
|
// VTCM scratchpads for all tensors
|
||||||
|
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||||
|
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||||
|
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||||
|
|
||||||
|
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||||
|
|
||||||
|
FARF(HIGH,
|
||||||
|
"%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||||
|
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||||
|
octx->dst_spad.size);
|
||||||
|
|
||||||
|
// Make sure the reserved vtcm size is sufficient
|
||||||
|
if (octx->ctx->vtcm_size < spad_size) {
|
||||||
|
FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
|
||||||
|
octx->ctx->vtcm_size, spad_size);
|
||||||
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||||
|
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||||
|
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||||
|
|
||||||
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||||
|
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||||
|
|
||||||
|
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||||
|
|
||||||
|
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int op_binary(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
switch (octx->src0.type) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
err = execute_op_binary_f32(octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = HTP_STATUS_NO_SUPPORT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,157 @@
|
||||||
|
if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||||
|
return()
|
||||||
|
endif()
|
||||||
|
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||||
|
|
||||||
|
#Cross Compiling for Hexagon
|
||||||
|
set(HEXAGON TRUE)
|
||||||
|
set(CMAKE_SYSTEM_NAME QURT)
|
||||||
|
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||||
|
set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL})
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||||
|
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||||
|
set(CUSTOM_RUNELF_PATH "")
|
||||||
|
|
||||||
|
#To fix backward compatibility with EAI addon.
|
||||||
|
if (NOT HEXAGON_SDK_ROOT)
|
||||||
|
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||||
|
endif()
|
||||||
|
|
||||||
|
if (NOT HEXAGON_TOOLS_ROOT)
|
||||||
|
if (DEFINED ENV{HEXAGON_TOOLS_ROOT})
|
||||||
|
set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})
|
||||||
|
endif()
|
||||||
|
if(NOT HEXAGON_TOOLS_ROOT)
|
||||||
|
set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
|
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||||
|
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||||
|
|
||||||
|
#Get the Binary extension of the Hexagon Toolchain
|
||||||
|
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||||
|
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||||
|
endif()
|
||||||
|
message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}")
|
||||||
|
|
||||||
|
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)
|
||||||
|
|
||||||
|
set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})
|
||||||
|
set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib")
|
||||||
|
set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)
|
||||||
|
|
||||||
|
set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||||
|
HEXAGON_SDK_ROOT
|
||||||
|
HEXAGON_TOOLS_ROOT
|
||||||
|
)
|
||||||
|
|
||||||
|
#QURT Related includes and linker flags
|
||||||
|
set(V_ARCH ${HEXAGON_ARCH})
|
||||||
|
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||||
|
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||||
|
|
||||||
|
if( ${TREE} MATCHES PAKMAN )
|
||||||
|
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||||
|
endif()
|
||||||
|
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||||
|
set(RTOS_DIR ${_QURT_INSTALL_DIR})
|
||||||
|
set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0")
|
||||||
|
set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0")
|
||||||
|
|
||||||
|
include_directories(
|
||||||
|
${_QURT_INSTALL_DIR}/include
|
||||||
|
${_QURT_INSTALL_DIR}/include/qurt
|
||||||
|
${_QURT_INSTALL_DIR}/include/posix
|
||||||
|
)
|
||||||
|
|
||||||
|
set(QURT_START_LINK_LIBS)
|
||||||
|
set(QURT_START_LINK_LIBS
|
||||||
|
"${TARGET_DIR}/init.o"
|
||||||
|
"${RTOS_DIR}/lib/crt1.o"
|
||||||
|
"${RTOS_DIR}/lib/debugmon.o"
|
||||||
|
"${RTOS_DIR}/lib/libqurt.a"
|
||||||
|
"${TARGET_DIR}/libc.a"
|
||||||
|
"${TARGET_DIR}/libqcc.a"
|
||||||
|
"${TARGET_DIR}/libhexagon.a"
|
||||||
|
"${RTOS_DIR}/lib/libqurtcfs.a"
|
||||||
|
"${RTOS_DIR}/lib/libtimer_island.a"
|
||||||
|
"${RTOS_DIR}/lib/libtimer_main.a"
|
||||||
|
"${RTOS_DIR}/lib/libposix.a"
|
||||||
|
)
|
||||||
|
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||||
|
|
||||||
|
set(QURT_END_LINK_LIBS
|
||||||
|
${TARGET_DIR}/fini.o
|
||||||
|
)
|
||||||
|
|
||||||
|
#Non QURT related includes and linker flags
|
||||||
|
|
||||||
|
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||||
|
|
||||||
|
if (NOT NO_WRAP_MEM_API)
|
||||||
|
set(WRAP_MALLOC -Wl,--wrap=malloc)
|
||||||
|
set(WRAP_CALLOC -Wl,--wrap=calloc)
|
||||||
|
set(WRAP_FREE -Wl,--wrap=free)
|
||||||
|
set(WRAP_REALLOC -Wl,--wrap=realloc)
|
||||||
|
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
set(PIC_SHARED_LD_FLAGS
|
||||||
|
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||||
|
-G0
|
||||||
|
-fpic
|
||||||
|
-Wl,-Bsymbolic
|
||||||
|
-Wl,-L${TARGET_DIR_NOOS}/G0/pic
|
||||||
|
-Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/
|
||||||
|
-Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}
|
||||||
|
-shared
|
||||||
|
"-o <TARGET> <SONAME_FLAG><TARGET_SONAME>"
|
||||||
|
"<LINK_FLAGS>"
|
||||||
|
-Wl,--start-group
|
||||||
|
"<OBJECTS>"
|
||||||
|
"<LINK_LIBRARIES>"
|
||||||
|
-Wl,--end-group
|
||||||
|
-lc
|
||||||
|
)
|
||||||
|
STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||||
|
|
||||||
|
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||||
|
|
||||||
|
#System include paths
|
||||||
|
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||||
|
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||||
|
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||||
|
|
||||||
|
#LLVM toolchain setup
|
||||||
|
#Compiler paths, options and architecture
|
||||||
|
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||||
|
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||||
|
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||||
|
set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||||
|
set(HEXAGON_LINKER ${CMAKE_C_COMPILER})
|
||||||
|
set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||||
|
|
||||||
|
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||||
|
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||||
|
|
||||||
|
#Compiler Options
|
||||||
|
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||||
|
|
||||||
|
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||||
|
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||||
|
|
||||||
|
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||||
|
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||||
|
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||||
|
|
||||||
|
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||||
|
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||||
|
set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" )
|
||||||
|
|
||||||
|
#Linker Options
|
||||||
|
set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||||
|
set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||||
|
|
@ -0,0 +1,40 @@
|
||||||
|
#ifndef HTP_CTX_H
|
||||||
|
#define HTP_CTX_H
|
||||||
|
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "worker-pool.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <dspqueue.h>
|
||||||
|
#include <stdatomic.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#define HTP_MAX_NTHREADS 10
|
||||||
|
|
||||||
|
// FIXME: move these into matmul-ops
|
||||||
|
#define HTP_SPAD_SRC0_NROWS 16
|
||||||
|
#define HTP_SPAD_SRC1_NROWS 16
|
||||||
|
#define HTP_SPAD_DST_NROWS 2
|
||||||
|
|
||||||
|
// Main context for htp DSP backend
|
||||||
|
struct htp_context {
|
||||||
|
dspqueue_t queue;
|
||||||
|
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||||
|
worker_pool_context_t worker_pool;
|
||||||
|
uint32_t n_threads;
|
||||||
|
|
||||||
|
int thread_id;
|
||||||
|
int thread_prio;
|
||||||
|
|
||||||
|
uint8_t * vtcm_base;
|
||||||
|
size_t vtcm_size;
|
||||||
|
uint32_t vtcm_rctx;
|
||||||
|
|
||||||
|
atomic_bool vtcm_valid;
|
||||||
|
atomic_bool vtcm_inuse;
|
||||||
|
atomic_bool vtcm_needs_release;
|
||||||
|
|
||||||
|
uint32_t opmask;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* HTP_CTX_H */
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
#include "htp-dma.h"
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
|
||||||
|
static inline uint32_t pow2_ceil(uint32_t x) {
|
||||||
|
if (x <= 1) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
int p = 2;
|
||||||
|
x--;
|
||||||
|
while (x >>= 1) {
|
||||||
|
p <<= 1;
|
||||||
|
}
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
dma_queue * dma_queue_create(size_t capacity) {
|
||||||
|
dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));
|
||||||
|
if (q == NULL) {
|
||||||
|
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
capacity = pow2_ceil(capacity);
|
||||||
|
|
||||||
|
memset(q, 0, sizeof(dma_queue));
|
||||||
|
q->capacity = capacity;
|
||||||
|
q->idx_mask = capacity - 1;
|
||||||
|
|
||||||
|
q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
|
||||||
|
memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
|
||||||
|
|
||||||
|
q->dst = (void **) memalign(4, capacity * sizeof(void *));
|
||||||
|
memset(q->dst, 0, capacity * sizeof(void *));
|
||||||
|
|
||||||
|
q->tail = &q->desc[capacity - 1];
|
||||||
|
|
||||||
|
if (!q->desc && !q->dst) {
|
||||||
|
FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
FARF(HIGH, "dma-queue: capacity %u\n", capacity);
|
||||||
|
|
||||||
|
return q;
|
||||||
|
}
|
||||||
|
|
||||||
|
void dma_queue_delete(dma_queue * q) {
|
||||||
|
if (!q) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
free(q->desc);
|
||||||
|
free(q->dst);
|
||||||
|
free(q);
|
||||||
|
}
|
||||||
|
|
||||||
|
void dma_queue_flush(dma_queue * q) {
|
||||||
|
while (1) {
|
||||||
|
uint32_t s = dmwait() & 0x3;
|
||||||
|
if (s == HEXAGON_UDMA_DM0_STATUS_IDLE) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
q->tail = NULL;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
#ifndef HTP_DMA_H
|
||||||
|
#define HTP_DMA_H
|
||||||
|
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
hexagon_udma_descriptor_type1_t * desc; // descriptor pointers
|
||||||
|
hexagon_udma_descriptor_type1_t * tail; // tail pointer
|
||||||
|
void ** dst; // dst pointers
|
||||||
|
uint32_t push_idx;
|
||||||
|
uint32_t pop_idx;
|
||||||
|
uint32_t capacity;
|
||||||
|
uint32_t idx_mask;
|
||||||
|
} dma_queue;
|
||||||
|
|
||||||
|
dma_queue * dma_queue_create(size_t capacity);
|
||||||
|
void dma_queue_delete(dma_queue * q);
|
||||||
|
void dma_queue_flush(dma_queue * q);
|
||||||
|
|
||||||
|
// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead
|
||||||
|
// but those do not seem to always compiler properly.
|
||||||
|
static inline void dmstart(void * next) {
|
||||||
|
asm volatile(" release(%0):at" : : "r"(next));
|
||||||
|
asm volatile(" dmstart(%0)" : : "r"(next));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void dmlink(void * cur, void * next) {
|
||||||
|
asm volatile(" release(%0):at" : : "r"(next));
|
||||||
|
asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline unsigned int dmpoll(void) {
|
||||||
|
unsigned int ret = 0;
|
||||||
|
asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline unsigned int dmwait(void) {
|
||||||
|
unsigned int ret = 0;
|
||||||
|
asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool dma_queue_push(dma_queue * q,
|
||||||
|
void * dst,
|
||||||
|
const void * src,
|
||||||
|
size_t dst_row_size,
|
||||||
|
size_t src_row_size,
|
||||||
|
size_t nrows) {
|
||||||
|
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
|
||||||
|
|
||||||
|
desc->next = NULL;
|
||||||
|
desc->length = 0;
|
||||||
|
desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
|
||||||
|
desc->dstbypass = 1;
|
||||||
|
desc->srcbypass = 1;
|
||||||
|
desc->order = 0;
|
||||||
|
desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
|
||||||
|
desc->src = (void *) src;
|
||||||
|
desc->dst = (void *) dst;
|
||||||
|
desc->allocation = 0;
|
||||||
|
desc->padding = 0;
|
||||||
|
desc->roiwidth = src_row_size;
|
||||||
|
desc->roiheight = nrows;
|
||||||
|
desc->srcstride = src_row_size;
|
||||||
|
desc->dststride = dst_row_size;
|
||||||
|
desc->srcwidthoffset = 0;
|
||||||
|
desc->dstwidthoffset = 0;
|
||||||
|
|
||||||
|
q->dst[q->push_idx] = dst;
|
||||||
|
|
||||||
|
dmlink(q->tail, desc);
|
||||||
|
q->tail = desc;
|
||||||
|
|
||||||
|
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
|
||||||
|
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline uint8_t * dma_queue_pop(dma_queue * q) {
|
||||||
|
if (q->push_idx == q->pop_idx) {
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
|
||||||
|
|
||||||
|
// Wait for desc to complete
|
||||||
|
while (1) {
|
||||||
|
dmpoll();
|
||||||
|
if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t * dst = (uint8_t *) q->dst[q->pop_idx];
|
||||||
|
|
||||||
|
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
|
||||||
|
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||||
|
return dst;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern "C"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif /* HTP_DMA_H */
|
||||||
|
|
@ -0,0 +1,156 @@
|
||||||
|
#ifndef HTP_MSG_H
|
||||||
|
#define HTP_MSG_H
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
// ggml-common.h must be included prio to this header
|
||||||
|
|
||||||
|
// Mask to enable various stages of the Ops.
|
||||||
|
// Used for debugging and profiling.
|
||||||
|
enum {
|
||||||
|
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||||
|
HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize
|
||||||
|
HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute
|
||||||
|
};
|
||||||
|
|
||||||
|
// Op flags
|
||||||
|
enum {
|
||||||
|
HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors)
|
||||||
|
HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling)
|
||||||
|
HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification
|
||||||
|
};
|
||||||
|
|
||||||
|
enum htp_status {
|
||||||
|
HTP_STATUS_OK = 1,
|
||||||
|
HTP_STATUS_INTERNAL_ERR = 2,
|
||||||
|
HTP_STATUS_NO_SUPPORT = 3,
|
||||||
|
HTP_STATUS_INVAL_PARAMS = 4,
|
||||||
|
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||||
|
};
|
||||||
|
|
||||||
|
// The values must match the ggml_type.
|
||||||
|
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||||
|
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||||
|
enum htp_data_type {
|
||||||
|
HTP_TYPE_F32 = 0,
|
||||||
|
HTP_TYPE_F16 = 1,
|
||||||
|
HTP_TYPE_Q4_0 = 2,
|
||||||
|
HTP_TYPE_Q8_0 = 8,
|
||||||
|
HTP_TYPE_MXFP4 = 39,
|
||||||
|
HTP_TYPE_COUNT
|
||||||
|
};
|
||||||
|
|
||||||
|
// These values are manually translated over to HTP
|
||||||
|
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
|
||||||
|
enum htp_op {
|
||||||
|
HTP_OP_MUL = 0,
|
||||||
|
HTP_OP_ADD = 1,
|
||||||
|
HTP_OP_SUB = 2,
|
||||||
|
HTP_OP_DIV = 3,
|
||||||
|
HTP_OP_MUL_MAT = 4,
|
||||||
|
HTP_OP_MUL_MAT_ID = 5,
|
||||||
|
HTP_OP_RMS_NORM = 6,
|
||||||
|
HTP_OP_UNARY_SILU = 7,
|
||||||
|
HTP_OP_GLU_SWIGLU = 8,
|
||||||
|
HTP_OP_GLU_SWIGLU_OAI = 9,
|
||||||
|
HTP_OP_SOFTMAX = 10,
|
||||||
|
HTP_OP_ADD_ID = 11,
|
||||||
|
HTP_OP_ROPE = 12,
|
||||||
|
INVALID
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline size_t htp_type_block_size(uint32_t t) {
|
||||||
|
switch (t) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
return 1;
|
||||||
|
case HTP_TYPE_F16:
|
||||||
|
return 1;
|
||||||
|
case HTP_TYPE_Q4_0:
|
||||||
|
return QK4_0;
|
||||||
|
case HTP_TYPE_Q8_0:
|
||||||
|
return QK8_0;
|
||||||
|
case HTP_TYPE_MXFP4:
|
||||||
|
return QK_MXFP4;
|
||||||
|
default:
|
||||||
|
assert(0 && "unsupported HTP data type");
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline size_t htp_type_nbytes(uint32_t t) {
|
||||||
|
switch (t) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
return 4;
|
||||||
|
case HTP_TYPE_F16:
|
||||||
|
return 2;
|
||||||
|
case HTP_TYPE_Q4_0:
|
||||||
|
return sizeof(block_q4_0);
|
||||||
|
case HTP_TYPE_Q8_0:
|
||||||
|
return sizeof(block_q8_0);
|
||||||
|
case HTP_TYPE_MXFP4:
|
||||||
|
return sizeof(block_mxfp4);
|
||||||
|
default:
|
||||||
|
assert(0 && "unsupported HTP data type");
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char * htp_type_name(uint32_t t) {
|
||||||
|
switch (t) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
return "fp32";
|
||||||
|
case HTP_TYPE_F16:
|
||||||
|
return "fp16";
|
||||||
|
case HTP_TYPE_Q4_0:
|
||||||
|
return "q4_0";
|
||||||
|
case HTP_TYPE_Q8_0:
|
||||||
|
return "q8_0";
|
||||||
|
case HTP_TYPE_MXFP4:
|
||||||
|
return "mxfp4";
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Internal types
|
||||||
|
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||||
|
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||||
|
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||||
|
|
||||||
|
#define HTP_MAX_DIMS 4
|
||||||
|
|
||||||
|
struct htp_tensor {
|
||||||
|
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
|
||||||
|
uint32_t type; // Data type
|
||||||
|
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
|
||||||
|
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||||
|
};
|
||||||
|
|
||||||
|
#define HTP_MAX_OP_PARAMS 64
|
||||||
|
|
||||||
|
struct htp_general_req {
|
||||||
|
uint32_t op; // GGML/HTP Op
|
||||||
|
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||||
|
// Params for the op, e.g. epsilon of RMS norm
|
||||||
|
uint32_t flags; // Request flags
|
||||||
|
|
||||||
|
struct htp_tensor src0; // Input0 tensor
|
||||||
|
struct htp_tensor src1; // Input1 tensor
|
||||||
|
struct htp_tensor src2; // Input2 tensor
|
||||||
|
struct htp_tensor dst; // Output tensor
|
||||||
|
|
||||||
|
// should be multiple of 64 bytes (cacheline)
|
||||||
|
};
|
||||||
|
|
||||||
|
struct htp_general_rsp {
|
||||||
|
uint32_t op; // GGML/HTP Op
|
||||||
|
uint32_t status; // HTP_STATUS_...
|
||||||
|
uint32_t prof_usecs; // Number of usec per request
|
||||||
|
uint32_t prof_cycles; // Number of cycles per request
|
||||||
|
uint32_t prof_pkts; // Number of instruction packets per request
|
||||||
|
uint8_t unused[44]; // Pad to 64 bytes
|
||||||
|
};
|
||||||
|
|
||||||
|
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
|
||||||
|
#define HTP_MAX_PACKET_BUFFERS 4
|
||||||
|
|
||||||
|
#endif /* HTP_MSG_H */
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
#ifndef HTP_OPS_H
|
||||||
|
#define HTP_OPS_H
|
||||||
|
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "worker-pool.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
// ggml-common.h must be included prior to this header
|
||||||
|
|
||||||
|
struct htp_spad {
|
||||||
|
uint8_t * data;
|
||||||
|
size_t size;
|
||||||
|
size_t size_per_thread;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct htp_ops_context {
|
||||||
|
struct htp_context * ctx;
|
||||||
|
|
||||||
|
enum htp_op op;
|
||||||
|
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||||
|
|
||||||
|
struct htp_tensor src0;
|
||||||
|
struct htp_tensor src1;
|
||||||
|
struct htp_tensor src2;
|
||||||
|
struct htp_tensor dst;
|
||||||
|
|
||||||
|
struct htp_spad src0_spad;
|
||||||
|
struct htp_spad src1_spad;
|
||||||
|
struct htp_spad src2_spad;
|
||||||
|
struct htp_spad dst_spad;
|
||||||
|
|
||||||
|
worker_pool_context_t * wpool; // worker pool
|
||||||
|
uint32_t n_threads; // num threads
|
||||||
|
|
||||||
|
uint32_t src0_nrows_per_thread;
|
||||||
|
uint32_t src1_nrows_per_thread;
|
||||||
|
|
||||||
|
uint32_t flags;
|
||||||
|
};
|
||||||
|
|
||||||
|
int op_matmul(struct htp_ops_context * octx);
|
||||||
|
int op_matmul_id(struct htp_ops_context * octx);
|
||||||
|
int op_binary(struct htp_ops_context * octx);
|
||||||
|
int op_unary(struct htp_ops_context * octx);
|
||||||
|
int op_activations(struct htp_ops_context * octx);
|
||||||
|
int op_softmax(struct htp_ops_context * octx);
|
||||||
|
int op_add_id(struct htp_ops_context * octx);
|
||||||
|
int op_rope(struct htp_ops_context * octx);
|
||||||
|
|
||||||
|
#endif /* HTP_OPS_H */
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
// FastRPC IDL interface for GGML HTP
|
||||||
|
|
||||||
|
#ifndef HTP_IDL
|
||||||
|
#define HTP_IDL
|
||||||
|
|
||||||
|
#include "AEEStdDef.idl"
|
||||||
|
#include "remote.idl"
|
||||||
|
|
||||||
|
interface htp_iface : remote_handle64 {
|
||||||
|
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
|
||||||
|
AEEResult stop();
|
||||||
|
AEEResult enable_etm();
|
||||||
|
AEEResult disable_etm();
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* HTP_IDL */
|
||||||
|
|
@ -0,0 +1,80 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector vec_out = Q6_V_vzero();
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
if (true == negate) {
|
||||||
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
||||||
|
*p_vec_out++ = hvx_vec_exp_fp32(neg_vec_in);
|
||||||
|
} else {
|
||||||
|
*p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
if (true == negate) {
|
||||||
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
|
||||||
|
} else {
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
if (true == negate) {
|
||||||
|
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||||
|
|
||||||
|
vec_out = hvx_vec_exp_fp32(neg_vec_in);
|
||||||
|
} else {
|
||||||
|
vec_out = hvx_vec_exp_fp32(in);
|
||||||
|
}
|
||||||
|
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * p_vec_in = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
*p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
HVX_Vector out = hvx_vec_inverse_fp32(in);
|
||||||
|
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// Reference algo used in hvx-utils
|
||||||
|
static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems)
|
||||||
|
{
|
||||||
|
const float c1 = 0.03138777;
|
||||||
|
const float c2 = 0.276281267;
|
||||||
|
const float c_log2f = 1.442695022;
|
||||||
|
|
||||||
|
int32_t store_ints[32];
|
||||||
|
float store_floats[3][32];
|
||||||
|
|
||||||
|
for (int i = 0; i < num_elems; i++)
|
||||||
|
{
|
||||||
|
float v = src0[i];
|
||||||
|
|
||||||
|
v *= c_log2f*0.5;
|
||||||
|
int intPart = (int)v;
|
||||||
|
float x = (v - intPart);
|
||||||
|
float xx = x * x;
|
||||||
|
float v1 = c_log2f + c2 * xx;
|
||||||
|
float v2 = x + xx * c1 * x;
|
||||||
|
float v3 = (v2 + v1);
|
||||||
|
*((int*)&v3) += intPart << 24;
|
||||||
|
float v4 = v2 - v1;
|
||||||
|
float v5 = v3 - v4;
|
||||||
|
float res = v3 / v5;
|
||||||
|
|
||||||
|
dst[i] = res;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
@ -0,0 +1,947 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
|
||||||
|
#define htp_binary_ops_preamble \
|
||||||
|
int step_of_4 = num_elems >> 7; \
|
||||||
|
int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6; \
|
||||||
|
int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5; \
|
||||||
|
int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \
|
||||||
|
\
|
||||||
|
const uint8_t * restrict src0_curr = src0; \
|
||||||
|
const uint8_t * restrict src1_curr = src1; \
|
||||||
|
uint8_t * restrict dst_curr = dst;
|
||||||
|
|
||||||
|
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||||
|
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||||
|
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_mul_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
htp_binary_ops_preamble;
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_4; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
src0_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b);
|
||||||
|
|
||||||
|
src1_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||||
|
|
||||||
|
dst_curr += 4 * VLEN;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_2; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
src0_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
src1_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
dst_curr += 2 * VLEN;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
src0_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
src1_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
|
||||||
|
dst_curr += VLEN;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (remaining > 0) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||||
|
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
const uint8_t * restrict src2,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
const uint8_t * restrict src0_curr = src0;
|
||||||
|
const uint8_t * restrict src1_curr = src1;
|
||||||
|
const uint8_t * restrict src2_curr = src2;
|
||||||
|
uint8_t * restrict dst_curr = dst;
|
||||||
|
|
||||||
|
int step_of_2 = num_elems >> 6;
|
||||||
|
int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5;
|
||||||
|
int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32;
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_2; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
HVX_Vector v1c = *(HVX_Vector *) src2_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN);
|
||||||
|
|
||||||
|
src0_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c);
|
||||||
|
|
||||||
|
src1_curr += 2 * VLEN;
|
||||||
|
src2_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
dst_curr += 2 * VLEN;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||||
|
src0_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||||
|
src1_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector vc = *(HVX_Vector *) src2_curr;
|
||||||
|
src2_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb);
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
dst_curr += VLEN;
|
||||||
|
}
|
||||||
|
if (remaining > 0) {
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr);
|
||||||
|
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_add_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||||
|
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||||
|
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_add_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
htp_binary_ops_preamble;
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_4; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
src0_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b);
|
||||||
|
|
||||||
|
src1_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||||
|
|
||||||
|
dst_curr += 4 * VLEN;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < step_of_2; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
src0_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
src1_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
dst_curr += 2 * VLEN;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
src0_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
src1_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
|
||||||
|
dst_curr += VLEN;
|
||||||
|
}
|
||||||
|
if (remaining > 0) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||||
|
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_sub_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||||
|
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||||
|
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||||
|
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
htp_binary_ops_preamble;
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_4; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
src0_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b);
|
||||||
|
|
||||||
|
src1_curr += 4 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||||
|
|
||||||
|
dst_curr += 4 * VLEN;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < step_of_2; i++) {
|
||||||
|
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
|
||||||
|
|
||||||
|
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||||
|
|
||||||
|
src0_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
|
||||||
|
|
||||||
|
src1_curr += 2 * VLEN;
|
||||||
|
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
|
||||||
|
dst_curr += 2 * VLEN;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||||
|
|
||||||
|
src0_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||||
|
|
||||||
|
src1_curr += VLEN;
|
||||||
|
|
||||||
|
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
|
||||||
|
dst_curr += VLEN;
|
||||||
|
}
|
||||||
|
if (remaining > 0) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||||
|
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||||
|
FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||||
|
|
||||||
|
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000);
|
||||||
|
HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1);
|
||||||
|
sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v);
|
||||||
|
vec_in1++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector vec_left = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left);
|
||||||
|
HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32);
|
||||||
|
|
||||||
|
sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc);
|
||||||
|
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||||
|
FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
|
||||||
|
HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * vec_in = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
// sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++);
|
||||||
|
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector vec_left = *(HVX_UVector *) srcf;
|
||||||
|
HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32);
|
||||||
|
// sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp);
|
||||||
|
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec);
|
||||||
|
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * vec_in1 = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||||
|
|
||||||
|
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
|
||||||
|
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
int num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
int unaligned_addr = 0;
|
||||||
|
int unaligned_loop = 0;
|
||||||
|
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||||
|
FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
unaligned_addr = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||||
|
unaligned_loop = 1;
|
||||||
|
FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector vec_max = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||||
|
HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||||
|
|
||||||
|
if (0 == unaligned_loop) {
|
||||||
|
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||||
|
|
||||||
|
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32);
|
||||||
|
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max);
|
||||||
|
return hvx_vec_get_fp32(v);
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||||
|
|
||||||
|
const float * src_f = (const float *) src;
|
||||||
|
|
||||||
|
HVX_Vector vec_min = Q6_V_vsplat_R(val);
|
||||||
|
|
||||||
|
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
|
||||||
|
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
||||||
|
const float limit_left,
|
||||||
|
const float limit_right,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems) {
|
||||||
|
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||||
|
size_t num_elems_whole = num_elems - left_over;
|
||||||
|
|
||||||
|
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||||
|
FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||||
|
|
||||||
|
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
HVX_Vector range_left = hvx_vec_splat_fp32(limit_left);
|
||||||
|
HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||||
|
HVX_Vector in_vec = *vec_in++;
|
||||||
|
HVX_Vector temp_v = in_vec;
|
||||||
|
|
||||||
|
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
|
||||||
|
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
|
||||||
|
|
||||||
|
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
|
||||||
|
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
|
||||||
|
|
||||||
|
*vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (left_over > 0) {
|
||||||
|
const float * srcf = (const float *) src + num_elems_whole;
|
||||||
|
float * dstf = (float *) dst + num_elems_whole;
|
||||||
|
|
||||||
|
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||||
|
|
||||||
|
HVX_Vector temp_v = in;
|
||||||
|
|
||||||
|
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
|
||||||
|
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
|
||||||
|
|
||||||
|
in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
|
||||||
|
in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
|
||||||
|
|
||||||
|
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,998 @@
|
||||||
|
#ifndef HVX_UTILS_H
|
||||||
|
#define HVX_UTILS_H
|
||||||
|
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#define SIZEOF_FP32 (4)
|
||||||
|
#define SIZEOF_FP16 (2)
|
||||||
|
#define VLEN (128)
|
||||||
|
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
|
||||||
|
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||||
|
union {
|
||||||
|
float f;
|
||||||
|
int32_t i;
|
||||||
|
} fp32 = { .f = i };
|
||||||
|
|
||||||
|
return Q6_V_vsplat_R(fp32.i);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
|
||||||
|
// Rotate as needed.
|
||||||
|
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
|
||||||
|
|
||||||
|
uint32_t left_off = (size_t) addr & 127;
|
||||||
|
uint32_t right_off = left_off + n;
|
||||||
|
|
||||||
|
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr);
|
||||||
|
HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);
|
||||||
|
|
||||||
|
if (right_off > 128) {
|
||||||
|
Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v);
|
||||||
|
// all 1's
|
||||||
|
qr = Q6_Q_vcmp_eq_VbVb(v, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
ql_not = Q6_Q_or_QQn(ql_not, qr);
|
||||||
|
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) {
|
||||||
|
assert((unsigned long) ptr % 128 == 0);
|
||||||
|
|
||||||
|
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr);
|
||||||
|
HVX_VectorPred qr = Q6_Q_vsetq2_R(n);
|
||||||
|
ql_not = Q6_Q_or_QQn(ql_not, qr);
|
||||||
|
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
|
||||||
|
// vdelta control to replicate first 4 bytes across all elements
|
||||||
|
static const uint8_t __attribute__((aligned(128))) repl[128] = {
|
||||||
|
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||||
|
};
|
||||||
|
|
||||||
|
HVX_Vector ctrl = *(HVX_Vector *) repl;
|
||||||
|
return Q6_V_vdelta_VV(v, ctrl);
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
|
||||||
|
static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||||
|
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) dst % 128 == 0);
|
||||||
|
assert((unsigned long) src % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 64;
|
||||||
|
uint32_t nloe = n % 64;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp16 elements : source is aligned, destination is potentially unaligned
|
||||||
|
static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||||
|
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) src % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 64;
|
||||||
|
uint32_t nloe = n % 64;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp16 elements : source is aligned, destination is potentially unaligned
|
||||||
|
static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||||
|
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) dst % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 64;
|
||||||
|
uint32_t nloe = n % 64;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
|
||||||
|
static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||||
|
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) dst % 128 == 0);
|
||||||
|
assert((unsigned long) src % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 32;
|
||||||
|
uint32_t nloe = n % 32;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp32 elements : source is aligned, destination is unaligned
|
||||||
|
static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||||
|
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) src % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 32;
|
||||||
|
uint32_t nloe = n % 32;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy n fp32 elements : source is unaligned, destination is aligned
|
||||||
|
static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||||
|
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||||
|
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||||
|
|
||||||
|
assert((unsigned long) dst % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 32;
|
||||||
|
uint32_t nloe = n % 32;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
vdst[i] = v;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
HVX_Vector v = vsrc[i];
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
|
||||||
|
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
|
||||||
|
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
HVX_Vector velem = hvx_vec_splat_fp32(elem);
|
||||||
|
|
||||||
|
assert((unsigned long) dst % 128 == 0);
|
||||||
|
|
||||||
|
uint32_t nvec = n / 32;
|
||||||
|
uint32_t nloe = n % 32;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (; i < nvec; i++) {
|
||||||
|
vdst[i] = velem;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (nloe) {
|
||||||
|
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||||
|
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||||
|
uint32_t right_off = left_off + n;
|
||||||
|
return right_off <= chunk_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
__fp16 d[64];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
const uint32_t n0 = n / 16;
|
||||||
|
const uint32_t n1 = n % 16;
|
||||||
|
int i = 0;
|
||||||
|
for (; i < n0; i++) {
|
||||||
|
htp_dump_fp16_line(pref, u.d + (16 * i), 16);
|
||||||
|
}
|
||||||
|
if (n1) {
|
||||||
|
htp_dump_fp16_line(pref, u.d + (16 * i), n1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) {
|
||||||
|
hvx_vec_dump_fp16_n(pref, v, 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
float d[32];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
const uint32_t n0 = n / 16;
|
||||||
|
const uint32_t n1 = n % 16;
|
||||||
|
int i = 0;
|
||||||
|
for (; i < n0; i++) {
|
||||||
|
htp_dump_fp32_line(pref, u.d + (16 * i), 16);
|
||||||
|
}
|
||||||
|
if (n1) {
|
||||||
|
htp_dump_fp32_line(pref, u.d + (16 * i), n1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
float d[32];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
|
||||||
|
u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) {
|
||||||
|
hvx_vec_dump_fp32_n(pref, v, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
int32_t d[32];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
for (int i = 0; i < 32 / 16; i++) {
|
||||||
|
htp_dump_int32_line(pref, u.d + (16 * i), 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
int32_t d[32];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
|
||||||
|
u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
int8_t d[128];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
|
||||||
|
u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
int8_t d[128];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
for (int i = 0; i < 128 / 16; i++) {
|
||||||
|
htp_dump_int8_line(pref, u.d + (16 * i), 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
|
||||||
|
union {
|
||||||
|
HVX_Vector v;
|
||||||
|
uint8_t d[128];
|
||||||
|
} u = { .v = v };
|
||||||
|
|
||||||
|
for (int i = 0; i < 128 / 16; i++) {
|
||||||
|
htp_dump_uint8_line(pref, u.d + (16 * i), 16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
|
||||||
|
typedef union {
|
||||||
|
HVX_Vector v;
|
||||||
|
int8_t d[128];
|
||||||
|
} U;
|
||||||
|
|
||||||
|
U u0 = { .v = v0 };
|
||||||
|
U u1 = { .v = v1 };
|
||||||
|
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
if (u0.d[i] != u1.d[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline float hvx_vec_get_fp32(HVX_Vector v) {
|
||||||
|
float __attribute__((aligned(128))) x;
|
||||||
|
hvx_vec_store_a(&x, 4, v);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||||
|
unsigned int total = n * 4; // total vec nbytes
|
||||||
|
unsigned int width = 4; // int32
|
||||||
|
|
||||||
|
HVX_Vector sum = in, sum_t;
|
||||||
|
while (width < total) {
|
||||||
|
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||||
|
sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) {
|
||||||
|
return hvx_vec_int32_reduce_sum_n(in, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||||
|
unsigned int total = n * 4; // total vec nbytes
|
||||||
|
unsigned int width = 4; // fp32 nbytes
|
||||||
|
|
||||||
|
HVX_Vector sum = in, sum_t;
|
||||||
|
while (width < total) {
|
||||||
|
sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
|
||||||
|
sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) {
|
||||||
|
return hvx_vec_qf32_reduce_sum_n(in, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||||
|
unsigned int total = n * 4; // total vec nbytes
|
||||||
|
unsigned int width = 4; // fp32 nbytes
|
||||||
|
|
||||||
|
HVX_Vector sum = in, sum_t;
|
||||||
|
while (width < total) {
|
||||||
|
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||||
|
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) {
|
||||||
|
return hvx_vec_fp32_reduce_sum_n(in, 32);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) {
|
||||||
|
unsigned total = 128; // total vec nbytes
|
||||||
|
unsigned width = 2; // fp16 nbytes
|
||||||
|
|
||||||
|
HVX_Vector _max = in, _max_t;
|
||||||
|
while (width < total) {
|
||||||
|
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||||
|
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return _max;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) {
|
||||||
|
unsigned total = 128; // total vec nbytes
|
||||||
|
unsigned width = 2; // fp32 nbytes
|
||||||
|
|
||||||
|
HVX_Vector _max_t;
|
||||||
|
|
||||||
|
_max = Q6_Vhf_vmax_VhfVhf(in, _max);
|
||||||
|
while (width < total) {
|
||||||
|
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||||
|
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return _max;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) {
|
||||||
|
unsigned total = 128; // total vec nbytes
|
||||||
|
unsigned width = 4; // fp32 nbytes
|
||||||
|
|
||||||
|
HVX_Vector _max = in, _max_t;
|
||||||
|
while (width < total) {
|
||||||
|
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||||
|
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return _max;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) {
|
||||||
|
unsigned total = 128; // total vec nbytes
|
||||||
|
unsigned width = 4; // fp32 nbytes
|
||||||
|
|
||||||
|
HVX_Vector _max_t;
|
||||||
|
|
||||||
|
_max = Q6_Vsf_vmax_VsfVsf(in, _max);
|
||||||
|
while (width < total) {
|
||||||
|
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||||
|
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||||
|
width = width << 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return _max;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
|
||||||
|
// abs by clearing the fp16 sign bit
|
||||||
|
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||||
|
return Q6_V_vand_VV(v, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
|
||||||
|
// neg by setting the fp16 sign bit
|
||||||
|
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
|
||||||
|
return Q6_V_vor_VV(v, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
||||||
|
// abs by clearing the fp32 sign bit
|
||||||
|
HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
|
||||||
|
return Q6_V_vand_VV(v, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
|
||||||
|
#if __HTP_ARCH__ > 75
|
||||||
|
return Q6_Vsf_vfneg_Vsf(v);
|
||||||
|
#else
|
||||||
|
// neg by setting the fp32 sign bit
|
||||||
|
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||||
|
return Q6_V_vor_VV(v, mask);
|
||||||
|
#endif // __HTP_ARCH__ > 75
|
||||||
|
}
|
||||||
|
|
||||||
|
// ====================================================
|
||||||
|
// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5
|
||||||
|
// Order:3; continuity: True; Ends forced: True
|
||||||
|
// Mode: unsigned; Result fractional bits: 14
|
||||||
|
// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05
|
||||||
|
// 32769 -32706 31252 -10589
|
||||||
|
// 32590 -30635 22793 -4493
|
||||||
|
// 32066 -27505 16481 -2348
|
||||||
|
// 31205 -24054 11849 -1306
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
|
||||||
|
// input is 0..0xffff representing 0.0 .. 1.0
|
||||||
|
HVX_Vector p;
|
||||||
|
p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
|
||||||
|
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
|
||||||
|
p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
|
||||||
|
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
|
||||||
|
return p; // signed result, 14 fractional bits
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find reciprocal of fp16.
|
||||||
|
// (1) first, convert to fp32, multiplying by 1.0; this is done to
|
||||||
|
// handle denormals. Ignoring sign and zero, result should be at
|
||||||
|
// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
|
||||||
|
// (exponent in range [103,143])
|
||||||
|
// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
|
||||||
|
// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
|
||||||
|
// (4) convert that to fp16
|
||||||
|
// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
|
||||||
|
// the result with the max value.
|
||||||
|
static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) {
|
||||||
|
HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
|
||||||
|
HVX_Vector avals = Q6_V_vand_VV(vals, em_mask);
|
||||||
|
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals);
|
||||||
|
// is too small to 1/x ? for 'standard' fp16, this would be 0x101
|
||||||
|
HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
|
||||||
|
|
||||||
|
HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0
|
||||||
|
HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
|
||||||
|
HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
|
||||||
|
|
||||||
|
// bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
|
||||||
|
HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
|
||||||
|
// likewise extract the upper 16 from each, containing the exponents in range 103..142
|
||||||
|
HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
|
||||||
|
//Get exponent in IEEE 32-bit representation
|
||||||
|
exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
|
||||||
|
|
||||||
|
// so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
|
||||||
|
// We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
|
||||||
|
// Use poly to transform to 1/x, with 14 fractional bits
|
||||||
|
//
|
||||||
|
HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
|
||||||
|
|
||||||
|
HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros
|
||||||
|
|
||||||
|
// Get mantissa for 16-bit represenation
|
||||||
|
HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
|
||||||
|
|
||||||
|
//Compute Reciprocal Exponent
|
||||||
|
HVX_Vector exp_recip =
|
||||||
|
Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
|
||||||
|
//Convert it for 16-bit representation
|
||||||
|
exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
|
||||||
|
exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
|
||||||
|
|
||||||
|
//Merge exponent and mantissa for reciprocal
|
||||||
|
HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
|
||||||
|
// map 'small' inputs to standard largest value 0x7bff
|
||||||
|
recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
|
||||||
|
// add sign back
|
||||||
|
recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
|
||||||
|
return recip;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define IEEE_VSF_EXPLEN (8)
|
||||||
|
#define IEEE_VSF_EXPBIAS (127)
|
||||||
|
#define IEEE_VSF_EXPMASK (0xFF)
|
||||||
|
#define IEEE_VSF_MANTLEN (23)
|
||||||
|
#define IEEE_VSF_MANTMASK (0x7FFFFF)
|
||||||
|
#define IEEE_VSF_MIMPMASK (0x800000)
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) {
|
||||||
|
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||||
|
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||||
|
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||||
|
|
||||||
|
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||||
|
|
||||||
|
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||||
|
expval_v &= IEEE_VSF_EXPMASK;
|
||||||
|
expval_v -= IEEE_VSF_EXPBIAS;
|
||||||
|
|
||||||
|
// negative exp == fractional value
|
||||||
|
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||||
|
|
||||||
|
HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift
|
||||||
|
|
||||||
|
HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa
|
||||||
|
HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0
|
||||||
|
|
||||||
|
vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer
|
||||||
|
vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0
|
||||||
|
|
||||||
|
HVX_Vector neg_vout = -vout;
|
||||||
|
|
||||||
|
vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives
|
||||||
|
|
||||||
|
return (vout);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) {
|
||||||
|
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||||
|
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||||
|
HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
|
||||||
|
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||||
|
HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf
|
||||||
|
|
||||||
|
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||||
|
|
||||||
|
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||||
|
expval_v &= IEEE_VSF_EXPMASK;
|
||||||
|
expval_v -= IEEE_VSF_EXPBIAS;
|
||||||
|
|
||||||
|
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||||
|
HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
|
||||||
|
HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
|
||||||
|
HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
|
||||||
|
|
||||||
|
// if expval < 0 (q_negexp) // <0, floor is 0
|
||||||
|
// if vin > 0
|
||||||
|
// floor = 0
|
||||||
|
// if vin < 0
|
||||||
|
// floor = -1
|
||||||
|
// if expval < mant_len (q_expltmn) // >0, but fraction may exist
|
||||||
|
// get sign (q_negative)
|
||||||
|
// mask >> expval // fraction bits to mask off
|
||||||
|
// vout = ~(mask) // apply mask to remove fraction
|
||||||
|
// if (qneg) // negative floor is one less (more, sign bit for neg)
|
||||||
|
// vout += ((impl_mask) >> expval)
|
||||||
|
// if (mask && vin)
|
||||||
|
// vout = vin
|
||||||
|
// else // already an integer
|
||||||
|
// ; // no change
|
||||||
|
|
||||||
|
// compute floor
|
||||||
|
mask_mant_v >>= expval_v;
|
||||||
|
HVX_Vector neg_addin_v = mask_impl_v >> expval_v;
|
||||||
|
HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
|
||||||
|
HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
|
||||||
|
|
||||||
|
HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set
|
||||||
|
HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
|
||||||
|
|
||||||
|
HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear
|
||||||
|
HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits
|
||||||
|
|
||||||
|
vout = in_vec;
|
||||||
|
vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant
|
||||||
|
vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values
|
||||||
|
vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0
|
||||||
|
vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1
|
||||||
|
|
||||||
|
return vout;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||||
|
// This looks complicated.
|
||||||
|
// Ideally should just be Q6_Vh_equals_Vhf(vin)
|
||||||
|
// but that instruction does not do proper rounding.
|
||||||
|
|
||||||
|
// convert to qf32, multiplying by 1.0 in the process.
|
||||||
|
HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
|
||||||
|
|
||||||
|
// 'in-range' values are +/32752.
|
||||||
|
// add 192K to it, convert to sf
|
||||||
|
HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
|
||||||
|
HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
|
||||||
|
HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
|
||||||
|
|
||||||
|
// for in-range cases, result is {163858... 229360} so the exponent is always 144.
|
||||||
|
// if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
|
||||||
|
// Start by <<10 to get the final 'sign' bit in bit 15...
|
||||||
|
vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
|
||||||
|
vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
|
||||||
|
|
||||||
|
// now round down to 16
|
||||||
|
return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
|
||||||
|
HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
|
||||||
|
HVX_Vector two_sf = hvx_vec_splat_fp32(2.0);
|
||||||
|
|
||||||
|
// First approximation
|
||||||
|
HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
|
||||||
|
|
||||||
|
HVX_Vector r_qf;
|
||||||
|
|
||||||
|
// Refine
|
||||||
|
r_qf = Q6_Vqf32_vmpy_VsfVsf(
|
||||||
|
i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
|
||||||
|
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||||
|
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||||
|
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||||
|
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||||
|
|
||||||
|
return Q6_Vsf_equals_Vqf32(r_qf);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||||
|
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||||
|
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
||||||
|
#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) {
|
||||||
|
v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||||
|
v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
|
||||||
|
|
||||||
|
HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||||
|
HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
|
||||||
|
HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
|
||||||
|
|
||||||
|
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
|
||||||
|
v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
|
||||||
|
v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
|
||||||
|
v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
|
||||||
|
|
||||||
|
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
|
||||||
|
HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
|
||||||
|
v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
|
||||||
|
v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
|
||||||
|
v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
|
||||||
|
HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
|
||||||
|
|
||||||
|
HVX_Vector res = hvx_vec_inverse_fp32(v5);
|
||||||
|
res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
|
||||||
|
|
||||||
|
return Q6_Vsf_equals_Vqf32(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
|
||||||
|
#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
|
||||||
|
#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
|
||||||
|
#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!)
|
||||||
|
#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
|
||||||
|
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
|
||||||
|
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||||
|
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||||
|
#define EXP_ONE (0x3f800000) // 1.0
|
||||||
|
#define EXP_RANGE_R (0x41a00000) // 20.0
|
||||||
|
#define EXP_RANGE_L (0xc1a00000) // -20.0
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) {
|
||||||
|
HVX_Vector z_qf32_v;
|
||||||
|
HVX_Vector x_v;
|
||||||
|
HVX_Vector x_qf32_v;
|
||||||
|
HVX_Vector y_v;
|
||||||
|
HVX_Vector k_v;
|
||||||
|
HVX_Vector f_v;
|
||||||
|
HVX_Vector epsilon_v;
|
||||||
|
HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
|
||||||
|
HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
|
||||||
|
HVX_Vector E_const;
|
||||||
|
HVX_Vector zero_v = Q6_V_vzero();
|
||||||
|
|
||||||
|
// exp(x) is approximated as follows:
|
||||||
|
// f = floor(x/ln(2)) = floor(x*log2(e))
|
||||||
|
// epsilon = x - f*ln(2)
|
||||||
|
// exp(x) = exp(epsilon+f*ln(2))
|
||||||
|
// = exp(epsilon)*exp(f*ln(2))
|
||||||
|
// = exp(epsilon)*2^f
|
||||||
|
//
|
||||||
|
// Since epsilon is close to zero, it can be approximated with its Taylor series:
|
||||||
|
// exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
|
||||||
|
// Preserving the first eight elements, we get:
|
||||||
|
// exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
|
||||||
|
// = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
|
||||||
|
|
||||||
|
HVX_Vector temp_v = in_vec;
|
||||||
|
|
||||||
|
// Clamp inputs to (-20.0, 20.0)
|
||||||
|
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
|
||||||
|
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||||
|
|
||||||
|
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
|
||||||
|
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
|
||||||
|
|
||||||
|
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
|
||||||
|
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
|
||||||
|
|
||||||
|
// f_v is the floating point result and k_v is the integer result
|
||||||
|
f_v = hvx_vec_floor_fp32(epsilon_v);
|
||||||
|
k_v = hvx_vec_truncate_fp32(f_v);
|
||||||
|
|
||||||
|
x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
|
||||||
|
|
||||||
|
// x = x - f_v * logn2;
|
||||||
|
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
|
||||||
|
x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
|
||||||
|
// normalize before every QFloat's vmpy
|
||||||
|
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
|
||||||
|
|
||||||
|
// z = x * x;
|
||||||
|
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
|
||||||
|
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
|
||||||
|
|
||||||
|
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||||
|
|
||||||
|
// y = E4 + E5 * x;
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
|
||||||
|
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_4);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = E3 + y * x;
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_3);
|
||||||
|
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = E2 + y * x;
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_2);
|
||||||
|
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = E1 + y * x;
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_1);
|
||||||
|
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = E0 + y * x;
|
||||||
|
E_const = Q6_V_vsplat_R(EXP_COEFF_0);
|
||||||
|
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = x + y * z;
|
||||||
|
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||||
|
|
||||||
|
// y = y + 1.0;
|
||||||
|
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
|
||||||
|
|
||||||
|
// insert exponents
|
||||||
|
// y = ldexpf(y, k);
|
||||||
|
// y_v += k_v; // qf32
|
||||||
|
// modify exponent
|
||||||
|
|
||||||
|
y_v = Q6_Vsf_equals_Vqf32(y_v);
|
||||||
|
|
||||||
|
// add k_v to the exponent of y_v
|
||||||
|
HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
|
||||||
|
|
||||||
|
y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
|
||||||
|
y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
|
||||||
|
|
||||||
|
// exponent cannot be negative; if overflow is detected, result is set to zero
|
||||||
|
HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
|
||||||
|
|
||||||
|
y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
|
||||||
|
|
||||||
|
y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
|
||||||
|
|
||||||
|
return y_v;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation
|
||||||
|
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||||
|
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||||
|
|
||||||
|
static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
|
||||||
|
//Algorithm :
|
||||||
|
// x2 = input*0.5
|
||||||
|
// y = * (long *) &input
|
||||||
|
// y = 0x5f3759df - (y>>2)
|
||||||
|
// y = y*(threehalfs - x2*y*y)
|
||||||
|
|
||||||
|
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||||
|
HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
|
||||||
|
HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
|
||||||
|
|
||||||
|
HVX_Vector x2, y, ypower2, temp;
|
||||||
|
|
||||||
|
x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
|
||||||
|
x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
|
||||||
|
|
||||||
|
y = Q6_Vw_vasr_VwR(in_vec, 1);
|
||||||
|
y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
|
||||||
|
|
||||||
|
// 1st iteration
|
||||||
|
ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
|
||||||
|
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||||
|
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||||
|
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||||
|
temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
|
||||||
|
|
||||||
|
// 2nd iteration
|
||||||
|
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||||
|
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||||
|
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||||
|
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||||
|
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||||
|
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||||
|
|
||||||
|
// 3rd iteration
|
||||||
|
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||||
|
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||||
|
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||||
|
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||||
|
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||||
|
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||||
|
|
||||||
|
return Q6_Vsf_equals_Vqf32(temp);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||||
|
int step_of_1 = num_elems >> 5;
|
||||||
|
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
||||||
|
|
||||||
|
assert(remaining == 0);
|
||||||
|
|
||||||
|
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
|
||||||
|
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_mul_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
const uint8_t * restrict src2,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_add_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_add_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_sub_f32(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
||||||
|
const uint8_t * restrict src1,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
|
||||||
|
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
|
||||||
|
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems);
|
||||||
|
float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems);
|
||||||
|
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||||
|
void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
||||||
|
const float limit_left,
|
||||||
|
const float limit_right,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems);
|
||||||
|
|
||||||
|
#endif /* HVX_UTILS_H */
|
||||||
|
|
@ -0,0 +1,945 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
|
||||||
|
#define FARF_ERROR 1
|
||||||
|
#define FARF_HIGH 1
|
||||||
|
#define FARF_MEDIUM 0
|
||||||
|
#define FARF_LOW 0
|
||||||
|
#include <AEEStdErr.h>
|
||||||
|
#include <dspqueue.h>
|
||||||
|
#include <HAP_compute_res.h>
|
||||||
|
#include <HAP_etm_config.h>
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_power.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <qurt.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <remote.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
#include "worker-pool.h"
|
||||||
|
|
||||||
|
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||||
|
struct htp_context * ctx;
|
||||||
|
int err = 0;
|
||||||
|
|
||||||
|
ctx = calloc(1, sizeof(*ctx));
|
||||||
|
if (ctx == NULL) {
|
||||||
|
return AEE_ENOMEMORY;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the context structure as a handle
|
||||||
|
*handle = (remote_handle64) ctx;
|
||||||
|
|
||||||
|
// Enable FARF logs
|
||||||
|
HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0);
|
||||||
|
|
||||||
|
// Set client class
|
||||||
|
{
|
||||||
|
HAP_power_request_t request;
|
||||||
|
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||||
|
request.type = HAP_power_set_apptype;
|
||||||
|
request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS;
|
||||||
|
|
||||||
|
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
HAP_power_request_t request;
|
||||||
|
memset(&request, 0, sizeof(request));
|
||||||
|
|
||||||
|
request.type = HAP_power_set_DCVS_v3;
|
||||||
|
request.dcvs_v3.set_dcvs_enable = TRUE;
|
||||||
|
request.dcvs_v3.dcvs_enable = TRUE;
|
||||||
|
request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE;
|
||||||
|
request.dcvs_v3.set_bus_params = TRUE;
|
||||||
|
request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.bus_params.target_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.set_core_params = TRUE;
|
||||||
|
request.dcvs_v3.core_params.min_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.core_params.max_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;
|
||||||
|
request.dcvs_v3.set_sleep_disable = TRUE;
|
||||||
|
request.dcvs_v3.sleep_disable = TRUE;
|
||||||
|
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
memset(&request, 0, sizeof(request));
|
||||||
|
request.type = HAP_power_set_HVX;
|
||||||
|
request.hvx.power_up = TRUE;
|
||||||
|
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Power on HMX
|
||||||
|
HAP_power_request_t request;
|
||||||
|
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||||
|
request.type = HAP_power_set_HMX;
|
||||||
|
request.hmx.power_up = TRUE;
|
||||||
|
FARF(ALWAYS, "Powering HMX on\n");
|
||||||
|
err = HAP_power_set((void *) &ctx, &request);
|
||||||
|
if (err != AEE_SUCCESS) {
|
||||||
|
FARF(ERROR, "Error powering on HMX.");
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult htp_iface_close(remote_handle64 handle) {
|
||||||
|
struct htp_context * ctx = (struct htp_context *) handle;
|
||||||
|
|
||||||
|
if (!ctx) {
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->queue) {
|
||||||
|
FARF(ERROR, "Closing handle with queue still open");
|
||||||
|
return AEE_EITEMBUSY;
|
||||||
|
}
|
||||||
|
|
||||||
|
free(ctx);
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult htp_iface_enable_etm(remote_handle64 handle) {
|
||||||
|
int err = HAP_user_etm_enable();
|
||||||
|
if (err) {
|
||||||
|
if (err == AEE_EVERSIONNOTSUPPORT) {
|
||||||
|
FARF(ERROR, "API HAP_user_etm_enable is not supported\n");
|
||||||
|
} else {
|
||||||
|
FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult htp_iface_disable_etm(remote_handle64 handle) {
|
||||||
|
int err = HAP_user_etm_disable();
|
||||||
|
if (err) {
|
||||||
|
if (err == AEE_EVERSIONNOTSUPPORT) {
|
||||||
|
FARF(ERROR, "API HAP_user_etm_disable is not supported\n");
|
||||||
|
} else {
|
||||||
|
FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int vtcm_acquire(struct htp_context * ctx) {
|
||||||
|
if (!ctx->vtcm_valid) {
|
||||||
|
// Temporarily bump thread priority to make sure it's higher than other sessions.
|
||||||
|
// This way the resource manager will notify the other thread to release VTCM.
|
||||||
|
// Note that we need to reaquire VTCM at normal priority for this to work next time.
|
||||||
|
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);
|
||||||
|
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||||
|
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||||
|
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);
|
||||||
|
|
||||||
|
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||||
|
ctx->vtcm_valid = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->vtcm_inuse = true;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int vtcm_release(struct htp_context * ctx) {
|
||||||
|
ctx->vtcm_inuse = false;
|
||||||
|
|
||||||
|
if (ctx->vtcm_valid && ctx->vtcm_needs_release) {
|
||||||
|
ctx->vtcm_valid = false;
|
||||||
|
ctx->vtcm_needs_release = false;
|
||||||
|
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int vtcm_release_callback(unsigned int rctx, void * state) {
|
||||||
|
struct htp_context * ctx = (struct htp_context *) state;
|
||||||
|
|
||||||
|
if (!ctx || ctx->vtcm_rctx != rctx) {
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If VTCM is not inuse (not processing Ops) release it right here
|
||||||
|
// otherwise we'll release it once we're done with the current Op.
|
||||||
|
|
||||||
|
if (ctx->vtcm_inuse) {
|
||||||
|
ctx->vtcm_needs_release = false;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->vtcm_valid = false;
|
||||||
|
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int vtcm_alloc(struct htp_context * ctx) {
|
||||||
|
unsigned int vtcm_size = 8 * 1024 * 1024; // 8MB default
|
||||||
|
HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL);
|
||||||
|
|
||||||
|
compute_res_attr_t attr;
|
||||||
|
HAP_compute_res_attr_init(&attr);
|
||||||
|
HAP_compute_res_attr_set_serialize(&attr, 0);
|
||||||
|
HAP_compute_res_attr_set_cache_mode(&attr, 1);
|
||||||
|
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size);
|
||||||
|
HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
|
||||||
|
HAP_compute_res_attr_set_hmx_param(&attr, 1);
|
||||||
|
|
||||||
|
// Allocate VTCM for scratch pads
|
||||||
|
uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */);
|
||||||
|
if (!rctx) {
|
||||||
|
FARF(ERROR, "failed to allocate %zu bytes VTCM\n", ctx->vtcm_size);
|
||||||
|
return AEE_ENOMEMORY;
|
||||||
|
}
|
||||||
|
|
||||||
|
void * vtcm_ptr;
|
||||||
|
if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) {
|
||||||
|
HAP_compute_res_release(rctx);
|
||||||
|
FARF(ERROR, "failed to allocate %zu bytes VTCM (new)\n", ctx->vtcm_size);
|
||||||
|
return AEE_ENOMEMORY;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->vtcm_base = (uint8_t *) vtcm_ptr;
|
||||||
|
ctx->vtcm_size = vtcm_size;
|
||||||
|
ctx->vtcm_rctx = rctx;
|
||||||
|
ctx->vtcm_valid = false;
|
||||||
|
ctx->vtcm_inuse = false;
|
||||||
|
ctx->vtcm_needs_release = false;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void vtcm_free(struct htp_context * ctx) {
|
||||||
|
if (ctx->vtcm_rctx) {
|
||||||
|
HAP_compute_res_release(ctx->vtcm_rctx);
|
||||||
|
ctx->vtcm_base = 0;
|
||||||
|
ctx->vtcm_rctx = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||||
|
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||||
|
|
||||||
|
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
|
||||||
|
struct htp_context * ctx = (struct htp_context *) handle;
|
||||||
|
|
||||||
|
if (!ctx) {
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->queue) {
|
||||||
|
FARF(ERROR, "Queue already open");
|
||||||
|
return AEE_EITEMBUSY;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import queue created on the CPU
|
||||||
|
int err = dspqueue_import(dsp_queue_id, // Queue ID from dspqueue_export
|
||||||
|
htp_packet_callback, // Packet callback
|
||||||
|
htp_error_callback, // Error callback; no errors expected on the DSP
|
||||||
|
(void *) ctx, // Callback context
|
||||||
|
&ctx->queue);
|
||||||
|
|
||||||
|
if (err) {
|
||||||
|
FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->thread_id = qurt_thread_get_id();
|
||||||
|
ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
|
||||||
|
|
||||||
|
// allocate VTCM
|
||||||
|
err = vtcm_alloc(ctx);
|
||||||
|
if (err != AEE_SUCCESS) {
|
||||||
|
FARF(ERROR, "Unable to allocate VTCM");
|
||||||
|
return AEE_ENOMEMORY;
|
||||||
|
}
|
||||||
|
|
||||||
|
qurt_sysenv_max_hthreads_t hw_threads;
|
||||||
|
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||||
|
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||||
|
|
||||||
|
if (n_hvx == 0) {
|
||||||
|
n_hvx = hw_nhvx;
|
||||||
|
}
|
||||||
|
if (n_hvx > hw_threads.max_hthreads) {
|
||||||
|
n_hvx = hw_threads.max_hthreads;
|
||||||
|
}
|
||||||
|
if (n_hvx > HTP_MAX_NTHREADS) {
|
||||||
|
n_hvx = HTP_MAX_NTHREADS;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->n_threads = n_hvx;
|
||||||
|
for (int i = 0; i < ctx->n_threads; i++) {
|
||||||
|
ctx->dma[i] = dma_queue_create(HTP_SPAD_SRC0_NROWS * 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// init worker pool
|
||||||
|
err = worker_pool_init(&ctx->worker_pool, n_hvx);
|
||||||
|
if (err != AEE_SUCCESS) {
|
||||||
|
FARF(ERROR, "Unable to create worker pool");
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n",
|
||||||
|
sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio);
|
||||||
|
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||||
|
struct htp_context * ctx = (struct htp_context *) handle;
|
||||||
|
if (!ctx) {
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ctx->queue) {
|
||||||
|
FARF(ERROR, "Queue not open");
|
||||||
|
return AEE_EBADSTATE;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close queue. dspqueue_close() will also wait for callbacks to finish.
|
||||||
|
int err = dspqueue_close(ctx->queue);
|
||||||
|
ctx->queue = NULL;
|
||||||
|
if (err != 0) {
|
||||||
|
FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err);
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->worker_pool) {
|
||||||
|
// Release worker pool
|
||||||
|
worker_pool_release(&ctx->worker_pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < ctx->n_threads; i++) {
|
||||||
|
dma_queue_delete(ctx->dma[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
vtcm_free(ctx);
|
||||||
|
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||||
|
// No errors expected on the DSP.
|
||||||
|
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct profile_data {
|
||||||
|
uint64_t usecs;
|
||||||
|
uint64_t cycles;
|
||||||
|
uint64_t pkts;
|
||||||
|
};
|
||||||
|
|
||||||
|
static inline void profile_start(struct profile_data * d) {
|
||||||
|
d->usecs = HAP_perf_get_qtimer_count();
|
||||||
|
d->cycles = htp_get_cycles();
|
||||||
|
d->pkts = htp_get_pktcnt();
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void profile_stop(struct profile_data * d) {
|
||||||
|
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
|
||||||
|
d->cycles = htp_get_cycles() - d->cycles;
|
||||||
|
d->pkts = htp_get_pktcnt() - d->pkts;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int send_htp_rsp(struct htp_context * c,
|
||||||
|
uint32_t op,
|
||||||
|
uint32_t status,
|
||||||
|
struct dspqueue_buffer * bufs,
|
||||||
|
size_t n_bufs,
|
||||||
|
struct profile_data * prof) {
|
||||||
|
// Prep response struct
|
||||||
|
struct htp_general_rsp rsp;
|
||||||
|
rsp.op = op;
|
||||||
|
rsp.status = status;
|
||||||
|
rsp.prof_usecs = prof->usecs;
|
||||||
|
rsp.prof_cycles = prof->cycles;
|
||||||
|
rsp.prof_pkts = prof->pkts;
|
||||||
|
|
||||||
|
int err = dspqueue_write(c->queue,
|
||||||
|
0, // Flags
|
||||||
|
n_bufs,
|
||||||
|
bufs, // Buffer references
|
||||||
|
sizeof(rsp),
|
||||||
|
(const uint8_t *) &rsp, // Message
|
||||||
|
DSPQUEUE_TIMEOUT_NONE);
|
||||||
|
|
||||||
|
if (err != 0) {
|
||||||
|
FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_matmul_req(struct htp_context * ctx,
|
||||||
|
struct htp_general_req * req,
|
||||||
|
struct dspqueue_buffer * bufs,
|
||||||
|
size_t n_bufs) {
|
||||||
|
// Prep response buffer structs (needed for error responses, etc)
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[2].fd = bufs[2].fd;
|
||||||
|
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||||
|
rsp_bufs[2].size = bufs[2].size;
|
||||||
|
rsp_bufs[2].offset = bufs[2].offset;
|
||||||
|
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_matmul(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||||
|
struct htp_general_req * req,
|
||||||
|
struct dspqueue_buffer * bufs,
|
||||||
|
size_t n_bufs) {
|
||||||
|
// Prep response buffer structs (needed for error responses, etc)
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[2].fd = bufs[2].fd;
|
||||||
|
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||||
|
rsp_bufs[2].size = bufs[2].size;
|
||||||
|
rsp_bufs[2].offset = bufs[2].offset;
|
||||||
|
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[3].fd = bufs[3].fd;
|
||||||
|
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||||
|
rsp_bufs[3].size = bufs[3].size;
|
||||||
|
rsp_bufs[3].offset = bufs[3].offset;
|
||||||
|
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
octx.src2 = req->src2;
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_matmul_id(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[2].fd = bufs[2].fd;
|
||||||
|
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||||
|
rsp_bufs[2].offset = bufs[2].offset;
|
||||||
|
rsp_bufs[2].size = bufs[2].size;
|
||||||
|
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_binary(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[2].fd = bufs[2].fd;
|
||||||
|
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||||
|
rsp_bufs[2].offset = bufs[2].offset;
|
||||||
|
rsp_bufs[2].size = bufs[2].size;
|
||||||
|
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[3].fd = bufs[3].fd;
|
||||||
|
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||||
|
rsp_bufs[3].offset = bufs[3].offset;
|
||||||
|
rsp_bufs[3].size = bufs[3].size;
|
||||||
|
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
octx.src2 = req->src2;
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_binary(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_unary(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_activations_req(struct htp_context * ctx,
|
||||||
|
struct htp_general_req * req,
|
||||||
|
struct dspqueue_buffer * bufs,
|
||||||
|
uint32_t n_bufs) {
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
int write_idx = 1;
|
||||||
|
if (3 == n_bufs) {
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
write_idx = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||||
|
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||||
|
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||||
|
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||||
|
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
if (3 == n_bufs) {
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
}
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
if (3 == n_bufs) {
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||||
|
} else {
|
||||||
|
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||||
|
}
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
if (octx.op == HTP_OP_SOFTMAX) {
|
||||||
|
rsp_status = op_softmax(&octx);
|
||||||
|
} else {
|
||||||
|
rsp_status = op_activations(&octx);
|
||||||
|
}
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void proc_rope_req(struct htp_context * ctx,
|
||||||
|
struct htp_general_req * req,
|
||||||
|
struct dspqueue_buffer * bufs,
|
||||||
|
uint32_t n_bufs) {
|
||||||
|
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||||
|
|
||||||
|
rsp_bufs[0].fd = bufs[0].fd;
|
||||||
|
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||||
|
rsp_bufs[0].offset = bufs[0].offset;
|
||||||
|
rsp_bufs[0].size = bufs[0].size;
|
||||||
|
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
rsp_bufs[1].fd = bufs[1].fd;
|
||||||
|
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||||
|
rsp_bufs[1].offset = bufs[1].offset;
|
||||||
|
rsp_bufs[1].size = bufs[1].size;
|
||||||
|
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
int write_idx = 2;
|
||||||
|
if (4 == n_bufs) {
|
||||||
|
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||||
|
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||||
|
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||||
|
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||||
|
rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||||
|
|
||||||
|
write_idx++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We had written to the output buffer, we'd also need to flush it
|
||||||
|
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||||
|
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||||
|
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||||
|
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||||
|
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||||
|
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||||
|
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||||
|
|
||||||
|
// Setup Op context
|
||||||
|
struct htp_ops_context octx = { 0 };
|
||||||
|
octx.ctx = ctx;
|
||||||
|
octx.src0 = req->src0;
|
||||||
|
octx.src1 = req->src1;
|
||||||
|
if (4 == n_bufs) {
|
||||||
|
octx.src2 = req->src2;
|
||||||
|
}
|
||||||
|
octx.dst = req->dst;
|
||||||
|
octx.flags = req->flags;
|
||||||
|
octx.op = req->op;
|
||||||
|
|
||||||
|
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||||
|
|
||||||
|
// Update data pointers
|
||||||
|
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||||
|
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||||
|
if (4 == n_bufs) {
|
||||||
|
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||||
|
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||||
|
} else {
|
||||||
|
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||||
|
}
|
||||||
|
octx.n_threads = ctx->n_threads;
|
||||||
|
|
||||||
|
struct profile_data prof;
|
||||||
|
profile_start(&prof);
|
||||||
|
|
||||||
|
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||||
|
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||||
|
rsp_status = op_rope(&octx);
|
||||||
|
vtcm_release(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
profile_stop(&prof);
|
||||||
|
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||||
|
struct htp_context * ctx = (struct htp_context *) context;
|
||||||
|
|
||||||
|
// Repeatedly read packets from the queue until it's empty. We don't
|
||||||
|
// necessarily get a separate callback for each packet, and new packets
|
||||||
|
// may arrive while we're processing the previous one. This ensures we
|
||||||
|
// keep the DSP busy as much as possible and avoid waiting for the CPU.
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
struct htp_general_req req;
|
||||||
|
uint32_t req_size;
|
||||||
|
|
||||||
|
struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
|
||||||
|
uint32_t n_bufs;
|
||||||
|
uint32_t flags;
|
||||||
|
|
||||||
|
// Read packet from queue
|
||||||
|
int err = dspqueue_read_noblock(queue, &flags,
|
||||||
|
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||||
|
&n_bufs, // Number of buffer references
|
||||||
|
bufs, // Buffer references
|
||||||
|
sizeof(req), // Max message length
|
||||||
|
&req_size, // Message length
|
||||||
|
(uint8_t *) &req); // Message
|
||||||
|
|
||||||
|
if (err == AEE_EWOULDBLOCK) {
|
||||||
|
// Consumed all packets available for now
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err != 0) {
|
||||||
|
FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req_size != sizeof(req)) {
|
||||||
|
FARF(ERROR, "Invalid request size");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {
|
||||||
|
// Host wants early notification
|
||||||
|
dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process packet based on its message type
|
||||||
|
switch (req.op) {
|
||||||
|
case HTP_OP_MUL_MAT:
|
||||||
|
if (n_bufs != 3) {
|
||||||
|
FARF(ERROR, "Bad matmul-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_MUL_MAT_ID:
|
||||||
|
if (n_bufs != 4) {
|
||||||
|
FARF(ERROR, "Bad matmul-id-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_matmul_id_req(ctx, &req, bufs, n_bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_MUL:
|
||||||
|
case HTP_OP_ADD:
|
||||||
|
case HTP_OP_SUB:
|
||||||
|
if (n_bufs != 3) {
|
||||||
|
FARF(ERROR, "Bad binary-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_binary_req(ctx, &req, bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_RMS_NORM:
|
||||||
|
if (n_bufs != 2) {
|
||||||
|
FARF(ERROR, "Bad unary-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
proc_unary_req(ctx, &req, bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_UNARY_SILU:
|
||||||
|
if (n_bufs != 2) {
|
||||||
|
FARF(ERROR, "Bad act-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_activations_req(ctx, &req, bufs, n_bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_GLU_SWIGLU:
|
||||||
|
case HTP_OP_SOFTMAX:
|
||||||
|
if ((n_bufs != 2) && (n_bufs != 3)) {
|
||||||
|
FARF(ERROR, "Bad act-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_activations_req(ctx, &req, bufs, n_bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_ADD_ID:
|
||||||
|
if (n_bufs != 4) {
|
||||||
|
FARF(ERROR, "Bad add-id-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_add_id_req(ctx, &req, bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case HTP_OP_ROPE:
|
||||||
|
if ((n_bufs != 3) && (n_bufs != 4)) {
|
||||||
|
FARF(ERROR, "Bad rope-req buffer list");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
proc_rope_req(ctx, &req, bufs, n_bufs);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unknown Op %u", req.op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,116 @@
|
||||||
|
#ifndef OPS_UTILS_H
|
||||||
|
#define OPS_UTILS_H
|
||||||
|
|
||||||
|
#include "htp-msg.h"
|
||||||
|
|
||||||
|
#ifndef MAX
|
||||||
|
# define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifndef MIN
|
||||||
|
# define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
|
#endif
|
||||||
|
|
||||||
|
static inline uint64_t htp_get_cycles() {
|
||||||
|
uint64_t cycles = 0;
|
||||||
|
asm volatile(" %0 = c15:14\n" : "=r"(cycles));
|
||||||
|
return cycles;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline uint64_t htp_get_pktcnt() {
|
||||||
|
uint64_t pktcnt;
|
||||||
|
asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
|
||||||
|
return pktcnt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
|
||||||
|
return ((size_t) addr & (align - 1)) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
|
||||||
|
return m * ((n + m - 1) / m);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
|
||||||
|
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||||
|
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||||
|
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||||
|
uint32_t right_off = left_off + n;
|
||||||
|
return right_off <= chunk_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
|
||||||
|
char str[1024], *p = str;
|
||||||
|
p += sprintf(p, "%s: ", pref);
|
||||||
|
for (int i = 0; i < 16; i++) {
|
||||||
|
p += sprintf(p, "%d, ", x[i]);
|
||||||
|
}
|
||||||
|
FARF(HIGH, "%s\n", str);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
|
||||||
|
char str[1024], *p = str;
|
||||||
|
p += sprintf(p, "%s: ", pref);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
p += sprintf(p, "%d, ", x[i]);
|
||||||
|
}
|
||||||
|
FARF(HIGH, "%s\n", str);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
|
||||||
|
char str[1024], *p = str;
|
||||||
|
p += sprintf(p, "%s: ", pref);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
p += sprintf(p, "%d, ", (int) x[i]);
|
||||||
|
}
|
||||||
|
FARF(HIGH, "%s\n", str);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
|
||||||
|
char str[1024], *p = str;
|
||||||
|
p += sprintf(p, "%s: ", pref);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
p += sprintf(p, "%.6f, ", (float) x[i]);
|
||||||
|
}
|
||||||
|
FARF(HIGH, "%s\n", str);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
|
||||||
|
char str[1024], *p = str;
|
||||||
|
p += sprintf(p, "%s: ", pref);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
p += sprintf(p, "%.6f, ", x[i]);
|
||||||
|
}
|
||||||
|
FARF(HIGH, "%s\n", str);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
|
||||||
|
uint32_t n0 = n / 16;
|
||||||
|
uint32_t n1 = n % 16;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
for (; i < n0; i++) {
|
||||||
|
htp_dump_fp32_line(pref, x + (16 * i), 16);
|
||||||
|
}
|
||||||
|
if (n1) {
|
||||||
|
htp_dump_fp32_line(pref, x + (16 * i), n1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
|
||||||
|
uint32_t n0 = n / 16;
|
||||||
|
uint32_t n1 = n % 16;
|
||||||
|
|
||||||
|
uint32_t i = 0;
|
||||||
|
for (; i < n0; i++) {
|
||||||
|
htp_dump_fp16_line(pref, x + (16 * i), 16);
|
||||||
|
}
|
||||||
|
if (n1) {
|
||||||
|
htp_dump_fp16_line(pref, x + (16 * i), n1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* OPS_UTILS_H */
|
||||||
|
|
@ -0,0 +1,418 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#define htp_rope_preamble \
|
||||||
|
const uint32_t ne00 = src0->ne[0]; \
|
||||||
|
const uint32_t ne01 = src0->ne[1]; \
|
||||||
|
const uint32_t ne02 = src0->ne[2]; \
|
||||||
|
const uint32_t ne03 = src0->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src0->nb[0]; \
|
||||||
|
const uint32_t nb01 = src0->nb[1]; \
|
||||||
|
const uint32_t nb02 = src0->nb[2]; \
|
||||||
|
const uint32_t nb03 = src0->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
struct rope_th_ctx {
|
||||||
|
int32_t n_dims;
|
||||||
|
int32_t mode;
|
||||||
|
int32_t n_ctx_orig;
|
||||||
|
int32_t sections[4];
|
||||||
|
|
||||||
|
float freq_base;
|
||||||
|
float freq_scale;
|
||||||
|
float ext_factor;
|
||||||
|
float attn_factor;
|
||||||
|
float beta_fast;
|
||||||
|
float beta_slow;
|
||||||
|
float theta_scale;
|
||||||
|
float corr_dims[2];
|
||||||
|
|
||||||
|
struct htp_ops_context * octx;
|
||||||
|
};
|
||||||
|
|
||||||
|
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||||
|
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||||
|
|
||||||
|
return (1 - MIN(1, MAX(0, y)));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rope_cache_init(const float theta_base,
|
||||||
|
float freq_scale,
|
||||||
|
const float * freq_factors,
|
||||||
|
float * corr_dims,
|
||||||
|
uint32_t ne0,
|
||||||
|
float ext_factor,
|
||||||
|
float mscale,
|
||||||
|
float * cache,
|
||||||
|
float theta_scale) {
|
||||||
|
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||||
|
float theta = theta_base;
|
||||||
|
|
||||||
|
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||||
|
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||||
|
|
||||||
|
float theta_extrap = theta / ff;
|
||||||
|
|
||||||
|
// Get n-d rotational scaling corrected for extrapolation
|
||||||
|
float theta_interp = freq_scale * theta_extrap;
|
||||||
|
float theta2 = theta_interp;
|
||||||
|
|
||||||
|
if (ext_factor != 0.0f) {
|
||||||
|
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||||
|
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||||
|
|
||||||
|
// Get n-d magnitude scaling corrected for interpolation
|
||||||
|
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
cache[i0 + 0] = cosf(theta2) * mscale;
|
||||||
|
cache[i0 + 1] = sinf(theta2) * mscale;
|
||||||
|
|
||||||
|
theta *= theta_scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define M_PI 3.1415926535897932384626433
|
||||||
|
|
||||||
|
static void rope_corr_dims(int n_dims,
|
||||||
|
int n_ctx_orig,
|
||||||
|
float freq_base,
|
||||||
|
float beta_fast,
|
||||||
|
float beta_slow,
|
||||||
|
float * dims) {
|
||||||
|
float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
|
||||||
|
float end = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
|
||||||
|
dims[0] = MAX(0, start);
|
||||||
|
dims[1] = MIN(n_dims - 1, end);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
|
||||||
|
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
|
||||||
|
|
||||||
|
const int32_t * op_params = &octx->op_params[0];
|
||||||
|
|
||||||
|
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
|
||||||
|
rope_ctx->mode = ((const int32_t *) op_params)[2];
|
||||||
|
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||||
|
|
||||||
|
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||||
|
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||||
|
|
||||||
|
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
|
||||||
|
|
||||||
|
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
|
||||||
|
rope_ctx->beta_slow, rope_ctx->corr_dims);
|
||||||
|
|
||||||
|
rope_ctx->octx = octx;
|
||||||
|
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
|
||||||
|
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||||
|
float * restrict dst,
|
||||||
|
const int num_elems,
|
||||||
|
const float * restrict theta_cache) {
|
||||||
|
// for (int i = 0; i < num_elems; i += 2) {
|
||||||
|
//const float cos_theta = theta_cache[i + 0];
|
||||||
|
//const float sin_theta = theta_cache[i + 1];
|
||||||
|
|
||||||
|
//const float x0 = src[0];
|
||||||
|
//const float x1 = src[1];
|
||||||
|
|
||||||
|
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||||
|
//dst[1] = x0*sin_theta + x1*cos_theta;
|
||||||
|
|
||||||
|
//src += 2;
|
||||||
|
//dst += 2;
|
||||||
|
// }
|
||||||
|
|
||||||
|
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||||
|
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||||
|
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||||
|
|
||||||
|
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||||
|
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||||
|
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||||
|
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||||
|
|
||||||
|
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||||
|
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||||
|
|
||||||
|
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||||
|
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||||
|
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||||
|
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||||
|
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||||
|
|
||||||
|
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
|
||||||
|
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
|
||||||
|
|
||||||
|
src0_curr += 2 * VLEN;
|
||||||
|
theta_curr += 2 * VLEN;
|
||||||
|
dst_curr += 2 * VLEN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||||
|
const uint32_t ir0,
|
||||||
|
const uint32_t ir1,
|
||||||
|
int nth,
|
||||||
|
int ith,
|
||||||
|
int opt_path) {
|
||||||
|
struct htp_ops_context * octx = rope_ctx->octx;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
const struct htp_tensor * src2 = &octx->src2;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
htp_rope_preamble;
|
||||||
|
|
||||||
|
const int32_t * pos = (const int32_t *) src1->data;
|
||||||
|
|
||||||
|
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
|
||||||
|
|
||||||
|
const float * freq_factors = NULL;
|
||||||
|
if (src2 != NULL) {
|
||||||
|
freq_factors = (const float *) src2->data;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ir = 0;
|
||||||
|
|
||||||
|
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||||
|
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||||
|
const int32_t p = pos[i2];
|
||||||
|
|
||||||
|
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||||
|
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||||
|
|
||||||
|
for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
||||||
|
if (ir++ < ir0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ir > ir1) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||||
|
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||||
|
|
||||||
|
const float * src_loc = src;
|
||||||
|
float * dst_data_loc = dst_data;
|
||||||
|
|
||||||
|
if (1 == opt_path) {
|
||||||
|
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||||
|
} else {
|
||||||
|
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||||
|
const float cos_theta = wp0[i0 + 0];
|
||||||
|
const float sin_theta = wp0[i0 + 1];
|
||||||
|
|
||||||
|
const float x0 = src_loc[0];
|
||||||
|
const float x1 = src_loc[1];
|
||||||
|
|
||||||
|
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||||
|
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||||
|
|
||||||
|
src_loc += 2;
|
||||||
|
dst_data_loc += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
|
||||||
|
dst_data_loc[0] = src_loc[0];
|
||||||
|
dst_data_loc[1] = src_loc[1];
|
||||||
|
|
||||||
|
src_loc += 2;
|
||||||
|
dst_data_loc += 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
|
||||||
|
struct htp_ops_context * octx = rope_ctx->octx;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
htp_rope_preamble;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||||
|
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||||
|
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||||
|
is_aligned = 0;
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
|
||||||
|
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
|
||||||
|
|
||||||
|
rope_job_f32_per_thread(rope_ctx, n, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
const struct htp_tensor * src2 = &octx->src2;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
worker_callback_t op_func;
|
||||||
|
const char * op_type = NULL;
|
||||||
|
|
||||||
|
struct rope_th_ctx rope_ctx;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_ROPE:
|
||||||
|
op_func = rope_job_dispatcher_f32;
|
||||||
|
op_type = "rope-f32";
|
||||||
|
|
||||||
|
init_rope_ctx(&rope_ctx, octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_threads = octx->n_threads;
|
||||||
|
|
||||||
|
const size_t src0_row_size = src0->nb[1];
|
||||||
|
const size_t src1_row_size = src0_row_size;
|
||||||
|
const size_t dst_row_size = dst->nb[1];
|
||||||
|
|
||||||
|
// VTCM scratchpads for all tensors
|
||||||
|
// N rows per thread, padded to HVX vector size
|
||||||
|
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||||
|
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||||
|
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||||
|
|
||||||
|
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||||
|
|
||||||
|
if (src2->ne[0]) {
|
||||||
|
FARF(HIGH,
|
||||||
|
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
|
||||||
|
"dst-spad-size %u\n",
|
||||||
|
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||||
|
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||||
|
} else {
|
||||||
|
FARF(HIGH,
|
||||||
|
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||||
|
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||||
|
octx->dst_spad.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the reserved vtcm size is sufficient
|
||||||
|
if (octx->ctx->vtcm_size < spad_size) {
|
||||||
|
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||||
|
spad_size);
|
||||||
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||||
|
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||||
|
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||||
|
|
||||||
|
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||||
|
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||||
|
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||||
|
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int op_rope(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
switch (octx->src0.type) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
err = execute_op_rope_f32(octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = HTP_STATUS_NO_SUPPORT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,402 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#define htp_softmax_preamble3 \
|
||||||
|
const uint32_t ne00 = src0->ne[0]; \
|
||||||
|
const uint32_t ne01 = src0->ne[1]; \
|
||||||
|
const uint32_t ne02 = src0->ne[2]; \
|
||||||
|
const uint32_t ne03 = src0->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src0->nb[0]; \
|
||||||
|
const uint32_t nb01 = src0->nb[1]; \
|
||||||
|
const uint32_t nb02 = src0->nb[2]; \
|
||||||
|
const uint32_t nb03 = src0->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
|
||||||
|
const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
|
||||||
|
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
|
||||||
|
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
|
||||||
|
\
|
||||||
|
const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
|
||||||
|
const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
|
||||||
|
const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
|
||||||
|
const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
struct softmax_th_ctx {
|
||||||
|
bool use_f16;
|
||||||
|
bool use_src1;
|
||||||
|
uint32_t n_head;
|
||||||
|
uint32_t n_head_log2;
|
||||||
|
|
||||||
|
float scale;
|
||||||
|
float max_bias;
|
||||||
|
float m0;
|
||||||
|
float m1;
|
||||||
|
|
||||||
|
struct htp_ops_context * octx;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
|
||||||
|
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
|
||||||
|
|
||||||
|
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
|
||||||
|
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
|
softmax_ctx->n_head = src0->ne[2];
|
||||||
|
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
|
||||||
|
|
||||||
|
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
|
||||||
|
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
|
||||||
|
|
||||||
|
softmax_ctx->use_src1 = (src1->ne[0] != 0);
|
||||||
|
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||||
|
|
||||||
|
softmax_ctx->octx = octx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
const int num_elems,
|
||||||
|
float scale,
|
||||||
|
const uint8_t * restrict mask,
|
||||||
|
float slope) {
|
||||||
|
const uint8_t * restrict src_curr = src;
|
||||||
|
uint8_t * restrict dst_curr = dst;
|
||||||
|
const uint8_t * restrict mask_curr = mask;
|
||||||
|
|
||||||
|
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||||
|
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
||||||
|
|
||||||
|
int step_of_1 = num_elems >> 5;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = *(HVX_Vector *) src_curr;
|
||||||
|
|
||||||
|
HVX_Vector v3 = *(HVX_Vector *) mask_curr;
|
||||||
|
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
|
||||||
|
|
||||||
|
HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);
|
||||||
|
|
||||||
|
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);
|
||||||
|
|
||||||
|
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);
|
||||||
|
|
||||||
|
src_curr += VLEN;
|
||||||
|
dst_curr += VLEN;
|
||||||
|
mask_curr += VLEN;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
uint8_t * restrict pad,
|
||||||
|
const int num_elems) {
|
||||||
|
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict v_pad = (HVX_Vector *) pad;
|
||||||
|
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
|
||||||
|
HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||||
|
HVX_Vector zero_v = Q6_V_vzero();
|
||||||
|
HVX_Vector one_v = hvx_vec_splat_fp32(1.0);
|
||||||
|
|
||||||
|
int step_of_1 = num_elems >> 5;
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = v_src[i];
|
||||||
|
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
|
||||||
|
max_vec = hvx_vec_repl4(v);
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = v_src[i];
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
|
||||||
|
|
||||||
|
HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
|
||||||
|
|
||||||
|
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
|
||||||
|
|
||||||
|
v_pad[i] = v3;
|
||||||
|
}
|
||||||
|
|
||||||
|
v = hvx_vec_qf32_reduce_sum(sum_vec);
|
||||||
|
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
|
||||||
|
|
||||||
|
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||||
|
HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec);
|
||||||
|
HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = v_pad[i];
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
|
||||||
|
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
uint8_t * restrict spad,
|
||||||
|
const int num_elems,
|
||||||
|
const float max) {
|
||||||
|
hvx_sub_scalar_f32(src, max, spad, num_elems);
|
||||||
|
|
||||||
|
hvx_exp_f32(spad, dst, num_elems, false);
|
||||||
|
|
||||||
|
float sum = hvx_self_sum_f32(dst, num_elems);
|
||||||
|
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
|
||||||
|
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
const struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
htp_softmax_preamble3;
|
||||||
|
|
||||||
|
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
|
||||||
|
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
|
||||||
|
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
|
||||||
|
|
||||||
|
float * wp0 = (float *) src0_spad_data;
|
||||||
|
float * wp1 = (float *) src1_spad_data;
|
||||||
|
float * wp2 = (float *) dst_spad_data;
|
||||||
|
|
||||||
|
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||||
|
const uint32_t i11 = i01;
|
||||||
|
const uint32_t i12 = i02 % ne12;
|
||||||
|
const uint32_t i13 = i03 % ne13;
|
||||||
|
|
||||||
|
// ALiBi
|
||||||
|
const uint32_t h = i02; // head
|
||||||
|
|
||||||
|
const float slope = (softmax_ctx->max_bias > 0.0f) ?
|
||||||
|
h < softmax_ctx->n_head_log2 ?
|
||||||
|
powf(softmax_ctx->m0, h + 1) :
|
||||||
|
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
|
||||||
|
1.0f;
|
||||||
|
|
||||||
|
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||||
|
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||||
|
|
||||||
|
// broadcast the mask across rows
|
||||||
|
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
|
||||||
|
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||||
|
NULL;
|
||||||
|
float * mp_f32 = (softmax_ctx->use_src1) ?
|
||||||
|
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||||
|
NULL;
|
||||||
|
|
||||||
|
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
|
||||||
|
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||||
|
(const uint8_t *) mp_f32, slope);
|
||||||
|
} else {
|
||||||
|
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
|
||||||
|
if (mp_f32) {
|
||||||
|
if (softmax_ctx->use_f16) {
|
||||||
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
wp0[i] += slope * (float) mp_f16[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < ne00; ++i) {
|
||||||
|
wp0[i] += slope * mp_f32[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (1 == opt_path) {
|
||||||
|
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||||
|
} else {
|
||||||
|
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
|
||||||
|
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||||
|
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||||
|
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
|
||||||
|
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
htp_softmax_preamble3;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||||
|
is_aligned = 0;
|
||||||
|
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||||
|
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||||
|
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
|
||||||
|
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
|
||||||
|
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
const struct htp_tensor * src1 = &octx->src1;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
worker_callback_t op_func;
|
||||||
|
const char * op_type = NULL;
|
||||||
|
|
||||||
|
struct softmax_th_ctx softmax_ctx;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_SOFTMAX:
|
||||||
|
op_func = softmax_job_dispatcher_f32;
|
||||||
|
op_type = "softmax-f32";
|
||||||
|
|
||||||
|
init_softmax_ctx(&softmax_ctx, octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t n_threads = octx->n_threads;
|
||||||
|
|
||||||
|
const size_t src0_row_size = src0->nb[1];
|
||||||
|
const size_t src1_row_size = src0_row_size;
|
||||||
|
const size_t dst_row_size = dst->nb[1];
|
||||||
|
|
||||||
|
// VTCM scratchpads for all tensors
|
||||||
|
// N rows per thread, padded to HVX vector size
|
||||||
|
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||||
|
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||||
|
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||||
|
|
||||||
|
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||||
|
|
||||||
|
if (src1->ne[0]) {
|
||||||
|
FARF(HIGH,
|
||||||
|
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||||
|
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||||
|
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||||
|
octx->dst_spad.size);
|
||||||
|
} else {
|
||||||
|
FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||||
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||||
|
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure the reserved vtcm size is sufficient
|
||||||
|
if (octx->ctx->vtcm_size < spad_size) {
|
||||||
|
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||||
|
spad_size);
|
||||||
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||||
|
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||||
|
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||||
|
|
||||||
|
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||||
|
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||||
|
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||||
|
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int op_softmax(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
switch (octx->src0.type) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
err = execute_op_softmax_f32(octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = HTP_STATUS_NO_SUPPORT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,255 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-function"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <HAP_farf.h>
|
||||||
|
#include <HAP_mem.h>
|
||||||
|
#include <HAP_perf.h>
|
||||||
|
#include <HAP_ps.h>
|
||||||
|
#include <hexagon_protos.h>
|
||||||
|
#include <hexagon_types.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <qurt_thread.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#define GGML_COMMON_DECL_C
|
||||||
|
#include "ggml-common.h"
|
||||||
|
#include "htp-ctx.h"
|
||||||
|
#include "htp-dma.h"
|
||||||
|
#include "htp-msg.h"
|
||||||
|
#include "htp-ops.h"
|
||||||
|
#include "hvx-utils.h"
|
||||||
|
#include "ops-utils.h"
|
||||||
|
|
||||||
|
#define htp_unary_preamble \
|
||||||
|
const uint32_t ne00 = src->ne[0]; \
|
||||||
|
const uint32_t ne01 = src->ne[1]; \
|
||||||
|
const uint32_t ne02 = src->ne[2]; \
|
||||||
|
const uint32_t ne03 = src->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t ne0 = dst->ne[0]; \
|
||||||
|
const uint32_t ne1 = dst->ne[1]; \
|
||||||
|
const uint32_t ne2 = dst->ne[2]; \
|
||||||
|
const uint32_t ne3 = dst->ne[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb00 = src->nb[0]; \
|
||||||
|
const uint32_t nb01 = src->nb[1]; \
|
||||||
|
const uint32_t nb02 = src->nb[2]; \
|
||||||
|
const uint32_t nb03 = src->nb[3]; \
|
||||||
|
\
|
||||||
|
const uint32_t nb0 = dst->nb[0]; \
|
||||||
|
const uint32_t nb1 = dst->nb[1]; \
|
||||||
|
const uint32_t nb2 = dst->nb[2]; \
|
||||||
|
const uint32_t nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||||
|
uint8_t * restrict dst,
|
||||||
|
uint8_t * restrict pad,
|
||||||
|
const int num_elems,
|
||||||
|
float epsilon) {
|
||||||
|
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||||
|
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||||
|
|
||||||
|
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||||
|
HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
|
||||||
|
|
||||||
|
int step_of_1 = num_elems >> 5;
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = v_src[i];
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||||
|
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||||
|
}
|
||||||
|
|
||||||
|
HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
|
||||||
|
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
|
||||||
|
|
||||||
|
HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems);
|
||||||
|
HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v);
|
||||||
|
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
||||||
|
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
||||||
|
|
||||||
|
HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
||||||
|
|
||||||
|
#pragma unroll(4)
|
||||||
|
for (int i = 0; i < step_of_1; i++) {
|
||||||
|
HVX_Vector v1 = v_src[i];
|
||||||
|
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||||
|
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void rms_norm_htp_f32(const float * restrict src,
|
||||||
|
float * restrict dst,
|
||||||
|
uint8_t * restrict spad,
|
||||||
|
const uint32_t num_rows,
|
||||||
|
const uint32_t row_elems,
|
||||||
|
const size_t row_size,
|
||||||
|
int32_t * op_params,
|
||||||
|
int opt_path) {
|
||||||
|
float epsilon = 0.f;
|
||||||
|
memcpy(&epsilon, op_params, sizeof(float));
|
||||||
|
|
||||||
|
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||||
|
const float * restrict src_local = src + (ir * row_elems);
|
||||||
|
float * restrict dst_local = dst + (ir * row_elems);
|
||||||
|
|
||||||
|
if (ir + 1 < num_rows) {
|
||||||
|
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (1 == opt_path) {
|
||||||
|
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||||
|
} else {
|
||||||
|
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
|
||||||
|
|
||||||
|
const float mean = sum / row_elems;
|
||||||
|
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||||
|
|
||||||
|
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||||
|
struct htp_tensor * dst,
|
||||||
|
uint8_t * spad,
|
||||||
|
int htp_op,
|
||||||
|
int32_t * op_params,
|
||||||
|
uint32_t nth,
|
||||||
|
uint32_t ith,
|
||||||
|
uint32_t src0_nrows_per_thread) {
|
||||||
|
htp_unary_preamble;
|
||||||
|
|
||||||
|
const size_t src0_row_size = nb01;
|
||||||
|
const size_t dst_row_size = nb1;
|
||||||
|
|
||||||
|
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||||
|
|
||||||
|
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||||
|
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||||
|
|
||||||
|
// no work for this thread
|
||||||
|
if (src0_start_row >= src0_end_row) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t t1, t2;
|
||||||
|
t1 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
int is_aligned = 1;
|
||||||
|
int opt_path = 0;
|
||||||
|
if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||||
|
is_aligned = 0;
|
||||||
|
FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
|
||||||
|
}
|
||||||
|
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||||
|
opt_path = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint8_t * restrict data_src = (const uint8_t *) src->data;
|
||||||
|
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||||
|
|
||||||
|
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||||
|
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||||
|
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
|
||||||
|
|
||||||
|
switch (htp_op) {
|
||||||
|
case HTP_OP_RMS_NORM:
|
||||||
|
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
t2 = HAP_perf_get_qtimer_count();
|
||||||
|
|
||||||
|
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
|
||||||
|
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||||
|
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||||
|
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||||
|
|
||||||
|
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
||||||
|
octx->src0_nrows_per_thread);
|
||||||
|
}
|
||||||
|
|
||||||
|
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
const struct htp_tensor * src0 = &octx->src0;
|
||||||
|
struct htp_tensor * dst = &octx->dst;
|
||||||
|
|
||||||
|
worker_callback_t unary_op_func;
|
||||||
|
const char * op_type = NULL;
|
||||||
|
|
||||||
|
switch (octx->op) {
|
||||||
|
case HTP_OP_RMS_NORM:
|
||||||
|
unary_op_func = unary_job_dispatcher_f32;
|
||||||
|
op_type = "rmsnorm-f32";
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||||
|
return HTP_STATUS_NO_SUPPORT;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_threads = octx->n_threads;
|
||||||
|
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||||
|
|
||||||
|
const size_t src0_row_size = src0->nb[1];
|
||||||
|
const size_t dst_row_size = dst->nb[1];
|
||||||
|
|
||||||
|
// VTCM scratchpads for all tensors
|
||||||
|
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||||
|
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||||
|
|
||||||
|
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||||
|
|
||||||
|
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||||
|
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||||
|
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||||
|
|
||||||
|
// Make sure the reserved vtcm size is sufficient
|
||||||
|
if (octx->ctx->vtcm_size < spad_size) {
|
||||||
|
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||||
|
spad_size);
|
||||||
|
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||||
|
}
|
||||||
|
|
||||||
|
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||||
|
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||||
|
|
||||||
|
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||||
|
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||||
|
|
||||||
|
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||||
|
|
||||||
|
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
||||||
|
int op_unary(struct htp_ops_context * octx) {
|
||||||
|
int err = HTP_STATUS_OK;
|
||||||
|
|
||||||
|
switch (octx->src0.type) {
|
||||||
|
case HTP_TYPE_F32:
|
||||||
|
err = execute_op_unary_f32(octx);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = HTP_STATUS_NO_SUPPORT;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return err;
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,297 @@
|
||||||
|
#include "worker-pool.h"
|
||||||
|
|
||||||
|
#include <qurt.h>
|
||||||
|
#include <stdatomic.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#ifdef HTP_DEBUG
|
||||||
|
# define FARF_HIGH 1
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "HAP_farf.h"
|
||||||
|
|
||||||
|
#define WORKER_THREAD_STACK_SZ (2 * 16384)
|
||||||
|
#define LOWEST_USABLE_QURT_PRIO (254)
|
||||||
|
|
||||||
|
struct worker_pool_s;
|
||||||
|
|
||||||
|
// internal structure kept in thread-local storage per instance of worker pool
|
||||||
|
typedef struct {
|
||||||
|
struct worker_pool_s * pool;
|
||||||
|
unsigned int id;
|
||||||
|
} worker_context_t;
|
||||||
|
|
||||||
|
// internal structure kept in thread-local storage per instance of worker pool
|
||||||
|
typedef struct worker_pool_s {
|
||||||
|
worker_pool_job_t job[MAX_NUM_WORKERS]; // list of job descriptors
|
||||||
|
qurt_thread_t thread[MAX_NUM_WORKERS]; // thread ID's of the workers
|
||||||
|
worker_context_t context[MAX_NUM_WORKERS]; // worker contexts
|
||||||
|
void * stack[MAX_NUM_WORKERS]; // thread stack pointers
|
||||||
|
unsigned int n_threads; // number of workers in this pool
|
||||||
|
|
||||||
|
atomic_uint seqn; // seqno used to detect new jobs
|
||||||
|
atomic_uint next_job; // next job index
|
||||||
|
atomic_uint n_pending; // number of pending jobs
|
||||||
|
atomic_uint n_jobs; // number of current jobs
|
||||||
|
atomic_bool killed; // threads need to exit
|
||||||
|
} worker_pool_t;
|
||||||
|
|
||||||
|
static void worker_pool_main(void * context) {
|
||||||
|
worker_context_t * me = (worker_context_t *) context;
|
||||||
|
worker_pool_t * pool = me->pool;
|
||||||
|
|
||||||
|
FARF(HIGH, "worker-pool: thread %u started", me->id);
|
||||||
|
|
||||||
|
unsigned int prev_seqn = 0;
|
||||||
|
while (!atomic_load(&pool->killed)) {
|
||||||
|
unsigned int seqn = atomic_load(&pool->seqn);
|
||||||
|
if (seqn == prev_seqn) {
|
||||||
|
// Nothing to do
|
||||||
|
qurt_futex_wait(&pool->seqn, prev_seqn);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// New job
|
||||||
|
prev_seqn = seqn;
|
||||||
|
|
||||||
|
unsigned int n = atomic_load(&pool->n_jobs);
|
||||||
|
unsigned int i = atomic_fetch_add(&pool->next_job, 1);
|
||||||
|
if (i >= n) {
|
||||||
|
// Spurios wakeup
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
pool->job[i].func(n, i, pool->job[i].data);
|
||||||
|
|
||||||
|
atomic_fetch_sub(&pool->n_pending, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
FARF(HIGH, "worker-pool: thread %u stopped", me->id);
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) {
|
||||||
|
int err = 0;
|
||||||
|
|
||||||
|
if (NULL == context) {
|
||||||
|
FARF(ERROR, "NULL context passed to worker_pool_init().");
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocations
|
||||||
|
int size = (stack_size * n_threads) + (sizeof(worker_pool_t));
|
||||||
|
|
||||||
|
unsigned char * mem_blob = (unsigned char *) malloc(size);
|
||||||
|
if (!mem_blob) {
|
||||||
|
FARF(ERROR, "Could not allocate memory for worker pool!!");
|
||||||
|
return AEE_ENOMEMORY;
|
||||||
|
}
|
||||||
|
|
||||||
|
worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads);
|
||||||
|
|
||||||
|
// name for the first worker, useful in debugging threads
|
||||||
|
char name[19];
|
||||||
|
snprintf(name, 12, "0x%8x:", (int) me);
|
||||||
|
strcat(name, "worker0");
|
||||||
|
me->n_threads = n_threads;
|
||||||
|
|
||||||
|
// initializations
|
||||||
|
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||||
|
me->stack[i] = NULL;
|
||||||
|
me->thread[i] = 0;
|
||||||
|
|
||||||
|
me->context[i].id = i;
|
||||||
|
me->context[i].pool = me;
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize job queue
|
||||||
|
me->n_pending = 0;
|
||||||
|
me->n_jobs = 0;
|
||||||
|
me->next_job = 0;
|
||||||
|
me->seqn = 0;
|
||||||
|
me->killed = 0;
|
||||||
|
|
||||||
|
// launch the workers
|
||||||
|
qurt_thread_attr_t attr;
|
||||||
|
qurt_thread_attr_init(&attr);
|
||||||
|
|
||||||
|
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||||
|
// set up stack
|
||||||
|
me->stack[i] = mem_blob;
|
||||||
|
mem_blob += stack_size;
|
||||||
|
qurt_thread_attr_set_stack_addr(&attr, me->stack[i]);
|
||||||
|
qurt_thread_attr_set_stack_size(&attr, stack_size);
|
||||||
|
|
||||||
|
// set up name
|
||||||
|
qurt_thread_attr_set_name(&attr, name);
|
||||||
|
name[17] = (name[17] + 1);
|
||||||
|
// name threads context:worker0, context:worker1, .. (recycle at 9, but num threads should be less than that anyway)
|
||||||
|
if (name[17] > '9') {
|
||||||
|
name[17] = '0';
|
||||||
|
}
|
||||||
|
|
||||||
|
// set up priority - by default, match the creating thread's prio
|
||||||
|
int prio = qurt_thread_get_priority(qurt_thread_get_id());
|
||||||
|
|
||||||
|
if (prio < 1) {
|
||||||
|
prio = 1;
|
||||||
|
}
|
||||||
|
if (prio > LOWEST_USABLE_QURT_PRIO) {
|
||||||
|
prio = LOWEST_USABLE_QURT_PRIO;
|
||||||
|
}
|
||||||
|
|
||||||
|
qurt_thread_attr_set_priority(&attr, prio);
|
||||||
|
|
||||||
|
// launch
|
||||||
|
err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]);
|
||||||
|
if (err) {
|
||||||
|
FARF(ERROR, "Could not launch worker threads!");
|
||||||
|
worker_pool_release((worker_pool_context_t *) &me);
|
||||||
|
return AEE_EQURTTHREADCREATE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*context = (worker_pool_context_t *) me;
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) {
|
||||||
|
return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ);
|
||||||
|
}
|
||||||
|
|
||||||
|
// clean up worker pool
|
||||||
|
void worker_pool_release(worker_pool_context_t * context) {
|
||||||
|
worker_pool_t * me = (worker_pool_t *) *context;
|
||||||
|
|
||||||
|
// if no worker pool exists, return error.
|
||||||
|
if (NULL == me) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic_store(&me->killed, 1);
|
||||||
|
atomic_fetch_add(&me->seqn, 1);
|
||||||
|
qurt_futex_wake(&me->seqn, me->n_threads);
|
||||||
|
|
||||||
|
// de-initializations
|
||||||
|
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||||
|
if (me->thread[i]) {
|
||||||
|
int status;
|
||||||
|
(void) qurt_thread_join(me->thread[i], &status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// free allocated memory (were allocated as a single buffer starting at stack[0])
|
||||||
|
if (me->stack[0]) {
|
||||||
|
free(me->stack[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
*context = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run jobs
|
||||||
|
AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) {
|
||||||
|
worker_pool_t * me = (worker_pool_t *) context;
|
||||||
|
if (NULL == me) {
|
||||||
|
FARF(ERROR, "worker-pool: invalid context");
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n > me->n_threads) {
|
||||||
|
FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads);
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(me->job, job, sizeof(worker_pool_job_t) * n);
|
||||||
|
|
||||||
|
if (n > 1) {
|
||||||
|
atomic_store(&me->next_job, 1);
|
||||||
|
atomic_store(&me->n_jobs, n);
|
||||||
|
atomic_store(&me->n_pending, n - 1);
|
||||||
|
|
||||||
|
// wake up workers
|
||||||
|
atomic_fetch_add(&me->seqn, 1);
|
||||||
|
qurt_futex_wake(&me->seqn, n - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// main thread runs job #0
|
||||||
|
me->job[0].func(n, 0, me->job[0].data);
|
||||||
|
|
||||||
|
if (n > 1) {
|
||||||
|
while (atomic_load(&me->n_pending))
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run func
|
||||||
|
AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) {
|
||||||
|
worker_pool_job_t job[n];
|
||||||
|
|
||||||
|
for (unsigned int i = 0; i < n; i++) {
|
||||||
|
job[i].func = func;
|
||||||
|
job[i].data = data;
|
||||||
|
}
|
||||||
|
|
||||||
|
return worker_pool_run_jobs(context, job, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) {
|
||||||
|
worker_pool_t * me = (worker_pool_t *) context;
|
||||||
|
|
||||||
|
// if no worker pool exists, return error.
|
||||||
|
if (!me) {
|
||||||
|
return AEE_ENOMORE;
|
||||||
|
}
|
||||||
|
|
||||||
|
int result = AEE_SUCCESS;
|
||||||
|
if (prio < 1) {
|
||||||
|
prio = 1;
|
||||||
|
}
|
||||||
|
if (prio > LOWEST_USABLE_QURT_PRIO) {
|
||||||
|
prio = LOWEST_USABLE_QURT_PRIO;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||||
|
int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio);
|
||||||
|
if (0 != res) {
|
||||||
|
result = AEE_EBADPARM;
|
||||||
|
FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) {
|
||||||
|
worker_pool_t * me = (worker_pool_t *) context;
|
||||||
|
if (!me) {
|
||||||
|
FARF(ERROR, "worker-pool: invalid context");
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < me->n_threads; i++) {
|
||||||
|
tids[i] = me->thread[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return AEE_SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) {
|
||||||
|
worker_pool_t * me = (worker_pool_t *) context;
|
||||||
|
if (!me) {
|
||||||
|
FARF(ERROR, "worker-pool: invalid context");
|
||||||
|
return AEE_EBADPARM;
|
||||||
|
}
|
||||||
|
|
||||||
|
int priority = qurt_thread_get_priority(me->thread[0]);
|
||||||
|
if (priority > 0) {
|
||||||
|
*prio = priority;
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
*prio = 0;
|
||||||
|
return AEE_EBADSTATE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
#ifndef HTP_WORKER_POOL_H
|
||||||
|
#define HTP_WORKER_POOL_H
|
||||||
|
|
||||||
|
// MACRO enables function to be visible in shared-library case.
|
||||||
|
#define WORKERPOOL_API __attribute__((visibility("default")))
|
||||||
|
|
||||||
|
#include <AEEStdDef.h>
|
||||||
|
#include <AEEStdErr.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/// signature of callbacks to be invoked by worker threads
|
||||||
|
typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *);
|
||||||
|
|
||||||
|
/// Typedef of worker_pool context
|
||||||
|
typedef void * worker_pool_context_t;
|
||||||
|
|
||||||
|
/// descriptor for requested callback
|
||||||
|
typedef struct {
|
||||||
|
worker_callback_t func;
|
||||||
|
void * data;
|
||||||
|
} worker_pool_job_t;
|
||||||
|
|
||||||
|
/// Maximum supported number of worker threads.
|
||||||
|
#define MAX_NUM_WORKERS 10
|
||||||
|
|
||||||
|
// Initialize worker pool.
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads);
|
||||||
|
|
||||||
|
// Initialize worker pool with custom stack size
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context,
|
||||||
|
uint32_t n_threads,
|
||||||
|
uint32_t stack_size);
|
||||||
|
|
||||||
|
// Kill worker threads and release worker pool resources
|
||||||
|
WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context);
|
||||||
|
|
||||||
|
// Run jobs with the worker pool.
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n);
|
||||||
|
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context,
|
||||||
|
worker_callback_t func,
|
||||||
|
void * data,
|
||||||
|
unsigned int n);
|
||||||
|
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio);
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio);
|
||||||
|
WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // #ifndef HTP_WORKER_POOL_H
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
0xffff
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
#!/bin/sh
|
||||||
|
#
|
||||||
|
|
||||||
|
# Basedir on device
|
||||||
|
basedir=/data/local/tmp/llama.cpp
|
||||||
|
|
||||||
|
branch=.
|
||||||
|
[ "$B" != "" ] && branch=$B
|
||||||
|
|
||||||
|
adbserial=
|
||||||
|
[ "$S" != "" ] && adbserial="-s $S"
|
||||||
|
|
||||||
|
model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||||
|
[ "$M" != "" ] && model="$M"
|
||||||
|
|
||||||
|
device="HTP0"
|
||||||
|
[ "$D" != "" ] && device="$D"
|
||||||
|
|
||||||
|
verbose=""
|
||||||
|
[ "$V" != "" ] && verbose="$V"
|
||||||
|
|
||||||
|
opmask=
|
||||||
|
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||||
|
|
||||||
|
nhvx=
|
||||||
|
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||||
|
|
||||||
|
ndev=
|
||||||
|
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
|
adb $adbserial shell " \
|
||||||
|
cd $basedir; \
|
||||||
|
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
$ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \
|
||||||
|
-t 4 --batch-size 128 -ngl 99 $@ \
|
||||||
|
"
|
||||||
|
|
@ -0,0 +1,52 @@
|
||||||
|
#!/bin/sh
|
||||||
|
#
|
||||||
|
|
||||||
|
# Basedir on device
|
||||||
|
basedir=/data/local/tmp/llama.cpp
|
||||||
|
|
||||||
|
cli_opts=
|
||||||
|
|
||||||
|
branch=.
|
||||||
|
[ "$B" != "" ] && branch=$B
|
||||||
|
|
||||||
|
adbserial=
|
||||||
|
[ "$S" != "" ] && adbserial="-s $S"
|
||||||
|
|
||||||
|
model="Llama-3.2-3B-Instruct-Q4_0.gguf"
|
||||||
|
[ "$M" != "" ] && model="$M"
|
||||||
|
|
||||||
|
device="HTP0"
|
||||||
|
[ "$D" != "" ] && device="$D"
|
||||||
|
|
||||||
|
verbose=
|
||||||
|
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||||
|
|
||||||
|
experimental=
|
||||||
|
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E"
|
||||||
|
|
||||||
|
sched=
|
||||||
|
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||||
|
|
||||||
|
profile=
|
||||||
|
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1"
|
||||||
|
|
||||||
|
opmask=
|
||||||
|
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||||
|
|
||||||
|
nhvx=
|
||||||
|
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||||
|
|
||||||
|
ndev=
|
||||||
|
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
|
adb $adbserial shell " \
|
||||||
|
cd $basedir; ulimit -c unlimited; \
|
||||||
|
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
$verbose $experimental $sched $opmask $profile $nhvx $ndev \
|
||||||
|
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||||
|
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||||
|
-ngl 99 --device $device $cli_opts $@ \
|
||||||
|
"
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
#!/bin/sh
|
||||||
|
#
|
||||||
|
|
||||||
|
# Basedir on device
|
||||||
|
basedir=/data/local/tmp/llama.cpp
|
||||||
|
|
||||||
|
cli_opts=
|
||||||
|
|
||||||
|
branch=.
|
||||||
|
[ "$B" != "" ] && branch=$B
|
||||||
|
|
||||||
|
adbserial=
|
||||||
|
[ "$S" != "" ] && adbserial="-s $S"
|
||||||
|
|
||||||
|
device="HTP0"
|
||||||
|
[ "$D" != "" ] && device="$D"
|
||||||
|
|
||||||
|
verbose=
|
||||||
|
[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V"
|
||||||
|
|
||||||
|
experimental=
|
||||||
|
[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$V"
|
||||||
|
|
||||||
|
sched=
|
||||||
|
[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v"
|
||||||
|
|
||||||
|
profile=
|
||||||
|
[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1"
|
||||||
|
|
||||||
|
opmask=
|
||||||
|
[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK"
|
||||||
|
|
||||||
|
nhvx=
|
||||||
|
[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX"
|
||||||
|
|
||||||
|
ndev=
|
||||||
|
[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV"
|
||||||
|
|
||||||
|
hb=
|
||||||
|
[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB"
|
||||||
|
|
||||||
|
set -x
|
||||||
|
|
||||||
|
tool=$1; shift
|
||||||
|
|
||||||
|
adb $adbserial shell " \
|
||||||
|
cd $basedir; ulimit -c unlimited; \
|
||||||
|
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||||
|
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb ./$branch/bin/$tool $@ \
|
||||||
|
"
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
This directory includes pytest based scripts for running CI jobs on Qualcomm Device Cloud (QDC).
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
Appium-Python-Client==5.2.4
|
||||||
|
attrs==25.4.0
|
||||||
|
certifi==2025.10.5
|
||||||
|
exceptiongroup==1.3.0
|
||||||
|
h11==0.16.0
|
||||||
|
idna==3.11
|
||||||
|
iniconfig==2.1.0
|
||||||
|
outcome==1.3.0.post0
|
||||||
|
packaging==25.0
|
||||||
|
pluggy==1.6.0
|
||||||
|
Pygments==2.19.2
|
||||||
|
PySocks==1.7.1
|
||||||
|
pytest==8.4.2
|
||||||
|
pytest-dependency==0.6.0
|
||||||
|
selenium==4.36.0
|
||||||
|
setuptools==80.9.0
|
||||||
|
sniffio==1.3.1
|
||||||
|
sortedcontainers==2.4.0
|
||||||
|
tomli==2.3.0
|
||||||
|
trio==0.31.0
|
||||||
|
trio-websocket==0.12.2
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
urllib3==2.5.0
|
||||||
|
websocket-client==1.9.0
|
||||||
|
wsproto==1.2.0
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
import pytest
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
tmp_path='/data/local/tmp'
|
||||||
|
pkg_path=f'{tmp_path}/llama.cpp'
|
||||||
|
lib_path=f'{pkg_path}/lib'
|
||||||
|
bin_path=f'{pkg_path}/bin'
|
||||||
|
|
||||||
|
model='../gguf/Llama-3.2-1B-Instruct-Q4_0.gguf'
|
||||||
|
cli_pref=f'cd {pkg_path} && LD_LIBRARY_PATH={lib_path} ADSP_LIBRARY_PATH={lib_path} {bin_path}'
|
||||||
|
|
||||||
|
|
||||||
|
def run_cmd(cmd):
|
||||||
|
p = subprocess.run(cmd, text = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT)
|
||||||
|
sys.stdout.write(p.stdout)
|
||||||
|
assert(p.returncode == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dependency()
|
||||||
|
def test_install():
|
||||||
|
run_cmd(['adb', 'push', 'llama.cpp', f'{tmp_path}'])
|
||||||
|
run_cmd(['adb', 'shell', f'chmod 755 {bin_path}/*'])
|
||||||
|
|
||||||
|
|
||||||
|
## Basic cli tests
|
||||||
|
def run_llama_cli(dev, opts):
|
||||||
|
prompt='what is the most popular cookie in the world?\nPlease provide a very brief bullet point summary.\nBegin your answer with **BEGIN**.'
|
||||||
|
opts = '--batch-size 128 -n 128 -no-cnv --seed 42 ' + opts
|
||||||
|
run_cmd(['adb', 'shell', f'{cli_pref}/llama-cli -m {model} --device {dev} -ngl 99 -t 4 {opts} -p "{prompt}"'])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dependency(depends=['test_install'])
|
||||||
|
def test_llama_cli_cpu():
|
||||||
|
run_llama_cli('none', '-ctk q8_0 -ctv q8_0 -fa on')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dependency(depends=['test_install'])
|
||||||
|
def test_llama_cli_gpu():
|
||||||
|
run_llama_cli('GPUOpenCL', '-fa on')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dependency(depends=['test_install'])
|
||||||
|
def test_llama_cli_npu():
|
||||||
|
run_llama_cli('HTP0', '-ctk q8_0 -ctv q8_0 -fa on')
|
||||||
|
|
||||||
|
|
||||||
|
## Basic bench tests
|
||||||
|
def run_llama_bench(dev):
|
||||||
|
run_cmd(['adb', 'shell', f'{cli_pref}/llama-bench -m {model} --device {dev} -ngl 99 --batch-size 128 -t 4 -p 128 -n 32'])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dependency(depends=['test_install'])
|
||||||
|
def test_llama_bench_cpu():
|
||||||
|
run_llama_bench('none')
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_bench_gpu():
|
||||||
|
run_llama_bench('GPUOpenCL')
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama_bench_npu():
|
||||||
|
run_llama_bench('HTP0')
|
||||||
|
|
@ -404,6 +404,19 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
|
||||||
// add the device default buffer type
|
// add the device default buffer type
|
||||||
buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
|
buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
|
||||||
|
|
||||||
|
// add the device extra buffer type (if any)
|
||||||
|
ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
|
||||||
|
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
|
||||||
|
ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts");
|
||||||
|
|
||||||
|
if (ggml_backend_dev_get_extra_bufts_fn) {
|
||||||
|
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev);
|
||||||
|
while (extra_bufts && *extra_bufts) {
|
||||||
|
buft_list.emplace_back(dev, *extra_bufts);
|
||||||
|
++extra_bufts;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return buft_list;
|
return buft_list;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue