Merge remote-tracking branch 'upstream/master' into cuda_graph_plan
This commit is contained in:
commit
d6e97a00d7
|
|
@ -1305,6 +1305,81 @@ jobs:
|
|||
cd examples/llama.android
|
||||
./gradlew build --no-daemon
|
||||
|
||||
android-ndk-build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
OPENCL_VERSION: 2025.07.22
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- build: 'arm64-cpu'
|
||||
defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_CURL=OFF -D GGML_OPENMP=OFF'
|
||||
- build: 'arm64-snapdragon'
|
||||
defines: '--preset arm64-android-snapdragon-release'
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install OpenCL Headers and Libs
|
||||
id: install_opencl
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
mkdir opencl
|
||||
curl -L -o opencl/clhpp.tar.gz https://github.com/KhronosGroup/OpenCL-CLHPP/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
curl -L -o opencl/headers.tar.gz https://github.com/KhronosGroup/OpenCL-Headers/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
curl -L -o opencl/icd-loader.tar.gz https://github.com/KhronosGroup/OpenCL-ICD-Loader/archive/refs/tags/v${OPENCL_VERSION}.tar.gz
|
||||
tar -xaf opencl/headers.tar.gz -C opencl
|
||||
tar -xaf opencl/clhpp.tar.gz -C opencl
|
||||
tar -xaf opencl/icd-loader.tar.gz -C opencl
|
||||
sudo cp -r opencl/OpenCL-Headers-${OPENCL_VERSION}/CL ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||
sudo cp -r opencl/OpenCL-CLHPP-${OPENCL_VERSION}/include/CL/* ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include/CL
|
||||
cd opencl/OpenCL-ICD-Loader-${OPENCL_VERSION}
|
||||
cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -DOPENCL_ICD_LOADER_HEADERS_DIR=${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=31 -DANDROID_STL=c++_shared
|
||||
cmake --build build
|
||||
sudo cp build/libOpenCL.so ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android
|
||||
rm -rf opencl
|
||||
|
||||
- name: Install Hexagon SDK
|
||||
id: install_hexsdk
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
env:
|
||||
HEXSDK_VER: 6.4.0.2
|
||||
HEXTLS_VER: 19.0.04
|
||||
run: |
|
||||
curl -L -o hex-sdk.tar.gz https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v$HEXSDK_VER/hexagon-sdk-v$HEXSDK_VER-amd64-lnx.tar.xz
|
||||
mkdir hex-sdk
|
||||
tar -xaf hex-sdk.tar.gz -C hex-sdk
|
||||
ls -l hex-sdk
|
||||
sudo mv hex-sdk /opt/hexagon
|
||||
echo "HEXAGON_SDK_ROOT=/opt/hexagon/$HEXSDK_VER" >> "$GITHUB_ENV"
|
||||
echo "HEXAGON_TOOLS_ROOT=/opt/hexagon/$HEXSDK_VER/tools/HEXAGON_Tools/$HEXTLS_VER" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_HLOS_ARCH=64" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_TOOLS_VARIANT=toolv19" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_NO_QURT_INC=0" >> "$GITHUB_ENV"
|
||||
echo "DEFAULT_DSP_ARCH=v73" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Update CMake presets
|
||||
id: update_presets
|
||||
if: ${{ matrix.build == 'arm64-snapdragon' }}
|
||||
run: |
|
||||
cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
|
||||
- name: Build
|
||||
id: ndk_build
|
||||
run: |
|
||||
cmake ${{ matrix.defines }} -B build
|
||||
cmake --build build
|
||||
cmake --install build --prefix pkg-adb/llama.cpp
|
||||
|
||||
- name: Test
|
||||
id: cmake_test
|
||||
run: |
|
||||
echo "FIXME: test on devices"
|
||||
|
||||
openEuler-latest-cmake-cann:
|
||||
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
|
||||
defaults:
|
||||
|
|
|
|||
|
|
@ -134,6 +134,8 @@ jobs:
|
|||
include:
|
||||
- build: 'x64'
|
||||
os: ubuntu-22.04
|
||||
- build: 's390x-z15' # z15 because our CI runners are on z15
|
||||
os: ubuntu-22.04-s390x
|
||||
# GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm
|
||||
# - build: 'arm64'
|
||||
# os: ubuntu-22.04-arm
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ name: Update Operations Documentation
|
|||
on:
|
||||
push:
|
||||
paths:
|
||||
- 'docs/ops.md'
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/ops.md'
|
||||
- 'docs/ops/**'
|
||||
- 'scripts/create_ops_docs.py'
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@
|
|||
/ggml/src/ggml-cuda/common.cuh @slaren
|
||||
/ggml/src/ggml-cuda/fattn* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/ggml-cuda.cu @slaren
|
||||
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmf.* @JohannesGaessler @am17an
|
||||
/ggml/src/ggml-cuda/mmq.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmvf.* @JohannesGaessler
|
||||
/ggml/src/ggml-cuda/mmvq.* @JohannesGaessler
|
||||
|
|
@ -65,6 +65,7 @@
|
|||
/ggml/src/ggml-impl.h @ggerganov @slaren
|
||||
/ggml/src/ggml-metal/ @ggerganov
|
||||
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
|
||||
/ggml/src/ggml-hexagon/ @max-krasnyansky
|
||||
/ggml/src/ggml-opt.cpp @JohannesGaessler
|
||||
/ggml/src/ggml-quants.* @ggerganov
|
||||
/ggml/src/ggml-rpc/ @rgerganov
|
||||
|
|
|
|||
|
|
@ -138,6 +138,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
|
||||
#### Multimodal
|
||||
|
||||
|
|
@ -187,6 +188,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- Swift [srgtuszy/llama-cpp-swift](https://github.com/srgtuszy/llama-cpp-swift)
|
||||
- Swift [ShenghaiWang/SwiftLlama](https://github.com/ShenghaiWang/SwiftLlama)
|
||||
- Delphi [Embarcadero/llama-cpp-delphi](https://github.com/Embarcadero/llama-cpp-delphi)
|
||||
- Go (no CGo needed): [hybridgroup/yzma](https://github.com/hybridgroup/yzma)
|
||||
|
||||
</details>
|
||||
|
||||
|
|
@ -278,6 +280,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
| [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE |
|
||||
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
|
||||
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
|
||||
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
|
||||
|
||||
## Obtaining and quantizing models
|
||||
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ if [ ! -z ${GG_BUILD_ROCM} ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DAMDGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGPU_TARGETS=${GG_BUILD_AMDGPU_TARGETS}"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_SYCL} ]; then
|
||||
|
|
|
|||
|
|
@ -1760,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
|
||||
add_opt(common_arg(
|
||||
{"-t", "--threads"}, "N",
|
||||
string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads),
|
||||
string_format("number of CPU threads to use during generation (default: %d)", params.cpuparams.n_threads),
|
||||
[](common_params & params, int value) {
|
||||
params.cpuparams.n_threads = value;
|
||||
if (params.cpuparams.n_threads <= 0) {
|
||||
|
|
@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
[](common_params & params) {
|
||||
params.use_jinja = true;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-format"}, "FORMAT",
|
||||
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
|
||||
|
|
|
|||
|
|
@ -41,9 +41,9 @@ static std::string build_repetition(const std::string & item_rule, int min_items
|
|||
return result;
|
||||
}
|
||||
|
||||
static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int>::max();
|
||||
static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) {
|
||||
auto has_min = min_value != std::numeric_limits<int64_t>::min();
|
||||
auto has_max = max_value != std::numeric_limits<int64_t>::max();
|
||||
|
||||
auto digit_range = [&](char from, char to) {
|
||||
out << "[";
|
||||
|
|
@ -159,7 +159,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
if (has_min) {
|
||||
if (min_value < 0) {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
_build_min_max_int(std::numeric_limits<int64_t>::min(), -min_value, out, decimals_left, /* top_level= */ false);
|
||||
out << ") | [0] | [1-9] ";
|
||||
more_digits(0, decimals_left - 1);
|
||||
} else if (min_value == 0) {
|
||||
|
|
@ -194,7 +194,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
}
|
||||
digit_range(c, c);
|
||||
out << " (";
|
||||
_build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
|
||||
_build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits<int64_t>::max(), out, less_decimals, /* top_level= */ false);
|
||||
out << ")";
|
||||
if (c < '9') {
|
||||
out << " | ";
|
||||
|
|
@ -216,7 +216,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
|
|||
_build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
|
||||
} else {
|
||||
out << "\"-\" (";
|
||||
_build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
|
||||
_build_min_max_int(-max_value, std::numeric_limits<int64_t>::max(), out, decimals_left, /* top_level= */ false);
|
||||
out << ")";
|
||||
}
|
||||
return;
|
||||
|
|
@ -925,17 +925,17 @@ public:
|
|||
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
|
||||
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
|
||||
} else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
|
||||
int min_value = std::numeric_limits<int>::min();
|
||||
int max_value = std::numeric_limits<int>::max();
|
||||
int64_t min_value = std::numeric_limits<int64_t>::min();
|
||||
int64_t max_value = std::numeric_limits<int64_t>::max();
|
||||
if (schema.contains("minimum")) {
|
||||
min_value = schema["minimum"].get<int>();
|
||||
min_value = schema["minimum"].get<int64_t>();
|
||||
} else if (schema.contains("exclusiveMinimum")) {
|
||||
min_value = schema["exclusiveMinimum"].get<int>() + 1;
|
||||
min_value = schema["exclusiveMinimum"].get<int64_t>() + 1;
|
||||
}
|
||||
if (schema.contains("maximum")) {
|
||||
max_value = schema["maximum"].get<int>();
|
||||
max_value = schema["maximum"].get<int64_t>();
|
||||
} else if (schema.contains("exclusiveMaximum")) {
|
||||
max_value = schema["exclusiveMaximum"].get<int>() - 1;
|
||||
max_value = schema["exclusiveMaximum"].get<int64_t>() - 1;
|
||||
}
|
||||
std::stringstream out;
|
||||
out << "(";
|
||||
|
|
|
|||
|
|
@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
|||
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
|
||||
import gguf
|
||||
from gguf.vocab import MistralTokenizerType, MistralVocab
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
)
|
||||
|
||||
_mistral_common_installed = True
|
||||
_mistral_import_error_msg = ""
|
||||
except ImportError:
|
||||
_MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
||||
_MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
_mistral_common_installed = False
|
||||
TokenizerVersion = None
|
||||
Tekkenizer = None
|
||||
SentencePieceTokenizer = None
|
||||
_mistral_import_error_msg = (
|
||||
"Mistral format requires `mistral-common` to be installed. Please run "
|
||||
"`pip install mistral-common[image,audio]` to install it."
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("hf-to-gguf")
|
||||
|
|
@ -107,6 +124,9 @@ class ModelBase:
|
|||
type(self) is MmprojModel:
|
||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||
|
||||
if self.is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
self.dir_model = dir_model
|
||||
self.ftype = ftype
|
||||
self.fname_out = fname_out
|
||||
|
|
@ -892,8 +912,8 @@ class TextModel(ModelBase):
|
|||
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
|
||||
res = "mellum"
|
||||
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
|
||||
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
|
||||
res = "llada-moe"
|
||||
# ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0
|
||||
res = "bailingmoe2"
|
||||
if chkhsh == "53e325976a6e142379c19b09afcae354f2f496f147afa8f9e189a33fe4e3024e":
|
||||
# ref: https://huggingface.co/ibm-granite/granite-docling-258M
|
||||
res = "granite-docling"
|
||||
|
|
@ -1363,8 +1383,8 @@ class MmprojModel(ModelBase):
|
|||
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"]))
|
||||
|
||||
# preprocessor config
|
||||
image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
|
||||
image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"]
|
||||
|
||||
self.gguf_writer.add_vision_image_mean(image_mean)
|
||||
self.gguf_writer.add_vision_image_std(image_std)
|
||||
|
|
@ -2033,6 +2053,9 @@ class LlamaModel(TextModel):
|
|||
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
|
||||
|
||||
def _set_vocab_mistral(self):
|
||||
if not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
vocab = MistralVocab(self.dir_model)
|
||||
logger.info(
|
||||
f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}."
|
||||
|
|
@ -8055,6 +8078,103 @@ class BailingMoeModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeV2ForCausalLM")
|
||||
class BailingMoeV2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0):
|
||||
self.block_count = self.hparams["num_hidden_layers"] + nextn_layers
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
if (rope_dim := hparams.get("head_dim")) is None:
|
||||
rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
else:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"]))
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_group_count(hparams["n_group"])
|
||||
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif hparams["score_function"] == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
else:
|
||||
raise ValueError(f"Unsupported score_function value: {hparams['score_function']}")
|
||||
|
||||
if (nextn_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
|
||||
self.gguf_writer.add_nextn_predict_layers(nextn_layers)
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if "mlp.experts" in name:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
|
||||
return tensors
|
||||
|
||||
if name.endswith(".expert_bias"):
|
||||
name = name.replace(".expert_bias", ".expert_bias.bias")
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
|
||||
class GroveMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GROVEMOE
|
||||
|
|
@ -9115,7 +9235,7 @@ class MistralModel(LlamaModel):
|
|||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None, "mistral_common is not installed"
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), (
|
||||
f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}"
|
||||
)
|
||||
|
|
@ -9497,6 +9617,8 @@ def main() -> None:
|
|||
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
|
||||
|
||||
is_mistral_format = args.mistral_format
|
||||
if is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
disable_mistral_community_chat_template = args.disable_mistral_community_chat_template
|
||||
|
||||
with torch.inference_mode():
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ models = [
|
|||
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
|
||||
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
|
||||
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
|
||||
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
|
||||
{"name": "bailingmoe2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
|
||||
{"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
{
|
||||
"version": 4,
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "arm64-android-snapdragon",
|
||||
"hidden": true,
|
||||
"architecture": { "value": "arm64", "strategy": "external" },
|
||||
"toolset": { "value": "host=x86_64", "strategy": "external" },
|
||||
"cacheVariables": {
|
||||
"ANDROID_ABI": "arm64-v8a",
|
||||
"ANDROID_PLATFORM": "android-31",
|
||||
"CMAKE_TOOLCHAIN_FILE": "$env{ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake",
|
||||
"CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE",
|
||||
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
|
||||
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "android_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"LLAMA_CURL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{
|
||||
"name": "arm64-windows-snapdragon",
|
||||
"inherits": [ "base", "arm64-windows-llvm" ],
|
||||
"cacheVariables": {
|
||||
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
|
||||
"PREBUILT_LIB_DIR": "windows_aarch64",
|
||||
"GGML_OPENMP": "OFF",
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"LLAMA_CURL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
{ "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] },
|
||||
|
||||
{ "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] },
|
||||
{ "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] }
|
||||
]
|
||||
}
|
||||
|
|
@ -0,0 +1,239 @@
|
|||
# Snapdragon-based Android devices
|
||||
|
||||
## How to Build
|
||||
|
||||
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
|
||||
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
|
||||
|
||||
This method works on Linux, macOS, and Windows. macOS and Windows users should install Docker Desktop.
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.3
|
||||
[d]/> cd /workspace
|
||||
```
|
||||
|
||||
The rest of the Android build process assumes that you're running inside the toolchain container.
|
||||
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
|
||||
|
||||
```
|
||||
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
|
||||
|
||||
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
|
||||
Preset CMake variables:
|
||||
ANDROID_ABI="arm64-v8a"
|
||||
...
|
||||
CMAKE_TOOLCHAIN_FILE="/opt/android-ndk-r28b/build/cmake/android.toolchain.cmake"
|
||||
GGML_HEXAGON="ON"
|
||||
GGML_OPENCL="ON"
|
||||
GGML_OPENMP="OFF"
|
||||
HEXAGON_SDK_ROOT="/opt/hexagon/6.4.0.2"
|
||||
...
|
||||
-- Including OpenCL backend
|
||||
-- Including Hexagon backend
|
||||
...
|
||||
-- Build files have been written to: /workspace/build-snapdragon
|
||||
|
||||
[d]/workspace> cmake --build build-snapdragon
|
||||
...
|
||||
[144/356] Performing build step for 'htp-v73'
|
||||
[1/16] Generating htp_iface_skel.c, htp_iface_stub.c, htp_iface.h
|
||||
[2/16] Building C object CMakeFiles/ggml-htp-v73.dir/hvx-sigmoid.c.obj
|
||||
[3/16] Building C object CMakeFiles/ggml-htp-v73.dir/htp-dma.c.obj
|
||||
[4/16] Building C object CMakeFiles/ggml-htp-v73.dir/worker-pool.c.obj
|
||||
...
|
||||
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v73.so
|
||||
-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v75.so
|
||||
...
|
||||
```
|
||||
|
||||
To generate an installable "package" simply use cmake --install:
|
||||
|
||||
```
|
||||
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp
|
||||
-- Install configuration: "Release"
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so
|
||||
...
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench
|
||||
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli
|
||||
...
|
||||
```
|
||||
|
||||
## How to Install
|
||||
|
||||
For this step, your device needs to be configured for on-device development.
|
||||
Please see https://developer.android.com/studio/debug/dev-options for details.
|
||||
|
||||
Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
|
||||
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/
|
||||
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
|
||||
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
|
||||
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
|
||||
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
|
||||
```
|
||||
|
||||
At this point, you should also install some models:
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ wget https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||
...
|
||||
2025-10-11 12:04:52 (10.7 MB/s) - ‘Llama-3.2-1B-Instruct-Q4_0.gguf’ saved [773025920/773025920]
|
||||
|
||||
~/src/llama.cpp$ adb push Llama-3.2-1B-Instruct-Q4_0.gguf /data/local/tmp/gguf
|
||||
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
|
||||
```
|
||||
|
||||
## How to Run
|
||||
|
||||
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.
|
||||
|
||||
llama.cpp supports three backends on Snapdragon-based devices: CPU, Adreno GPU (GPUOpenCL), and Hexagon NPU (HTP0-4).
|
||||
You can select which backend to run the model on using the `D=` variable, which maps to the `--device` option.
|
||||
|
||||
Hexagon NPU behaves as a "GPU" device when it comes to `-ngl` and other offload-related options.
|
||||
|
||||
Here are some examples of running various llama.cpp tools via ADB.
|
||||
|
||||
Simple question for Llama-3.2-1B
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?"
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v79
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb4000072c7955e50
|
||||
...
|
||||
load_tensors: offloading output layer to GPU
|
||||
load_tensors: offloaded 17/17 layers to GPU
|
||||
load_tensors: CPU model buffer size = 225.49 MiB
|
||||
load_tensors: HTP0 model buffer size = 0.26 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 504.00 MiB
|
||||
...
|
||||
I hope this helps you understand the world's most popular cookies! [end of text]
|
||||
...
|
||||
llama_perf_sampler_print: sampling time = 30.08 ms / 487 runs ( 0.06 ms per token, 16191.77 tokens per second)
|
||||
llama_perf_context_print: load time = 617.94 ms
|
||||
llama_perf_context_print: prompt eval time = 80.76 ms / 11 tokens ( 7.34 ms per token, 136.21 tokens per second)
|
||||
llama_perf_context_print: eval time = 9210.59 ms / 475 runs ( 19.39 ms per token, 51.57 tokens per second)
|
||||
llama_perf_context_print: total time = 9454.92 ms / 486 tokens
|
||||
llama_perf_context_print: graphs reused = 473
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 439 = 225 + 136 + 77 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 504 = 504 + 0 + 0 |
|
||||
```
|
||||
|
||||
Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v81
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: allocating new session: HTP1
|
||||
...
|
||||
load_tensors: offloading output layer to GPU
|
||||
load_tensors: offloaded 17/17 layers to GPU
|
||||
load_tensors: CPU model buffer size = 143.86 MiB
|
||||
load_tensors: HTP1 model buffer size = 0.23 MiB
|
||||
load_tensors: HTP1-REPACK model buffer size = 1575.00 MiB
|
||||
load_tensors: HTP0 model buffer size = 0.28 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 2025.00 MiB
|
||||
...
|
||||
llama_context: CPU output buffer size = 0.19 MiB
|
||||
llama_kv_cache: HTP1 KV buffer size = 238.00 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 306.00 MiB
|
||||
llama_kv_cache: size = 544.00 MiB ( 8192 cells, 16 layers, 1/1 seqs), K (q8_0): 272.00 MiB, V (q8_0): 272.00 MiB
|
||||
llama_context: HTP0 compute buffer size = 15.00 MiB
|
||||
llama_context: HTP1 compute buffer size = 15.00 MiB
|
||||
llama_context: CPU compute buffer size = 24.56 MiB
|
||||
...
|
||||
llama_perf_context_print: prompt eval time = 1730.57 ms / 212 tokens ( 8.16 ms per token, 122.50 tokens per second)
|
||||
llama_perf_context_print: eval time = 5624.75 ms / 257 runs ( 21.89 ms per token, 45.69 tokens per second)
|
||||
llama_perf_context_print: total time = 7377.33 ms / 469 tokens
|
||||
llama_perf_context_print: graphs reused = 255
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 742 = 144 + 544 + 54 |
|
||||
llama_memory_breakdown_print: | - HTP1-REPACK | 1575 = 1575 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 2025 = 2025 + 0 + 0 |
|
||||
```
|
||||
|
||||
Op test for MUL_MAT
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ HB=0 ./scripts/snapdragon/adb/run-tool.sh test-backend-ops -b HTP0 -o MUL_MAT
|
||||
...
|
||||
Backend 2/3: HTP0
|
||||
Device description: Hexagon
|
||||
Device memory: 2048 MB (2048 MB free)
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK
|
||||
|
||||
~/src/llama.cpp-hexagon$ M=Llama-3.2-1B-Instruct-Q4_0.gguf ./scripts/snapdragon/adb/run-bench.sh -p 128 -n 64
|
||||
...
|
||||
ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1
|
||||
ggml-hex: Hexagon Arch version v79
|
||||
ggml-hex: allocating new session: HTP0
|
||||
ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb400007d4b231090
|
||||
| model | size | params | backend | ngl | threads | n_batch | mmap | test | t/s |
|
||||
| ---------------| ---------: | -----: | ---------- | --: | ------: | ------: | ---: | ----: | ------------: |
|
||||
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | pp128 | 169.42 ± 1.75 |
|
||||
| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | tg64 | 51.54 ± 1.13 |
|
||||
|
||||
build: 6a8cf8914 (6733)
|
||||
```
|
||||
|
||||
## Environment variables
|
||||
|
||||
- `GGML_HEXAGON_NDEV=1`
|
||||
Controls the number of devices/sessions to allocate. The default is 1.
|
||||
Most quantized models under 4B fit into a single session; an 8B model needs two, and a 20B model needs four.
|
||||
|
||||
- `GGML_HEXAGON_NHVX=0`
|
||||
Controls the number of HVX hardware threads to use. The default is all (actual number varies depending on the hardware version).
|
||||
|
||||
- `GGML_HEXAGON_HOSTBUF=1`
|
||||
Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers.
|
||||
This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID).
|
||||
|
||||
- `GGML_HEXAGON_VERBOSE=1`
|
||||
Enables verbose logging of Ops from the backend. Example output:
|
||||
|
||||
```
|
||||
ggml-hex: HTP0 graph-compute n_nodes 2
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_up.weight x ffn_norm-27 -> ffn_up-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x1
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_gate.weight x ffn_norm-27 -> ffn_gate-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x3
|
||||
ggml-hex: HTP0 graph-compute n_nodes 1
|
||||
ggml-hex: HTP0 matmul : blk.27.ffn_down.weight x ffn_gate_par-27 -> ffn_out-27 : 8192:3072 x 8192:1 -> 3072:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x0
|
||||
ggml-hex: HTP0 get-tensor result_output : data 0x7592487000 offset 0 size 513024
|
||||
```
|
||||
|
||||
- `GGML_HEXAGON_PROFILE=1`
|
||||
Generates a host-side profile for the ggml-hexagon Ops.
|
||||
|
||||
- `GGML_HEXAGON_OPMASK=0x0`
|
||||
Allows enabling specific stages of the processing pipeline:
|
||||
|
||||
- `0x1` Enable Op Queue (i.e., queuing Ops into NPU)
|
||||
- `0x2` Enable Dynamic Quantizer (if needed for the Op)
|
||||
- `0x4` Enable Op Compute (MUL_MAT, etc.)
|
||||
|
||||
Examples:
|
||||
|
||||
`GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out
|
||||
`GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest
|
||||
`GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default)
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
# Hexagon backend developer details
|
||||
|
||||
## Backend libraries
|
||||
|
||||
The Hexagon backend consist of two parts:
|
||||
|
||||
- `libggml-hexagon`
|
||||
This is the regular CPU-side GGML backend library, either shared or statically linked
|
||||
|
||||
- `libggml-htp-vNN`
|
||||
This is the NPU-side (HTP stands for Hexagon Tensor Processor) shared library that contains the Op dispatcher and kernels.
|
||||
The correct library is selected automatically at runtime based on the HW version.
|
||||
|
||||
Here is an example of the build artifacts
|
||||
|
||||
```
|
||||
~/src/llama.cpp$ ls -l pkg-adb/llama.cpp/lib/libggml*
|
||||
pkg-adb/llama.cpp/lib/libggml-base.so
|
||||
pkg-adb/llama.cpp/lib/libggml-cpu.so
|
||||
pkg-adb/llama.cpp/lib/libggml-hexagon.so <<< CPU library
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v73.so <<< HTP op/kernels for Hexagon v73
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v75.so
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v79.so
|
||||
pkg-adb/llama.cpp/lib/libggml-htp-v81.so
|
||||
```
|
||||
|
||||
## Memory buffers
|
||||
|
||||
Hexagon NPU backend takes advantage of the Snapdragon's unified memory model where all buffers are fully accessible by the CPU and GPU.
|
||||
The NPU does have a dedicated tightly-coupled memory called VTCM but that memory is used only for intermediate data (e.g. dynamically
|
||||
quantized tensors) or temporary data (chunks of the weight tensors fetched via DMA).
|
||||
|
||||
Please note that currently the Hexagon backend does not implement SET/GET_ROWS Ops because there is no advantage in offloading those
|
||||
to the NPU at this point.
|
||||
|
||||
The backend does allocates non-host buffers for the tensors with datatypes that require repacking: Q4_0, Q8_0, MXFP4.
|
||||
From the MMU perspective these buffers are still regular buffers (normal access by the CPU) they are marked as non-host simply to force
|
||||
the repacking.
|
||||
|
||||
## Large model handling
|
||||
|
||||
Hexagon NPU session (aka Process Domain (PD) in the Hexagon docs) is limited to a memory mapping of around 3.5GB.
|
||||
In llama.cpp/GGML the Hexagon session is mapped to a single GGML backend device (HTP0, HTP1, etc).
|
||||
|
||||
In order to map models larger than 3.5GB we need to allocate multiple devices and split the model.
|
||||
For this we're taking advantage of the llama.cpp/GGML multi-GPU layer-splitting support.
|
||||
Each Hexagon device behaves like a GPU from the offload and model splitting perspective.
|
||||
|
||||
Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR.
|
||||
|
||||
```
|
||||
M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32
|
||||
...
|
||||
LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib
|
||||
GGML_HEXAGON_NDEV=4 ./bin/llama-cli --no-mmap -m /data/local/tmp/llama.cpp/../gguf/gpt-oss-20b-Q4_0.gguf
|
||||
-t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on -ngl 99 --device HTP0,HTP1,HTP2,HTP3 -no-cnv -f surfing.txt
|
||||
...
|
||||
llama_model_loader: - type f32: 289 tensors
|
||||
llama_model_loader: - type q4_0: 96 tensors
|
||||
llama_model_loader: - type q8_0: 2 tensors
|
||||
llama_model_loader: - type mxfp4: 72 tensors
|
||||
...
|
||||
load_tensors: offloaded 25/25 layers to GPU
|
||||
load_tensors: CPU model buffer size = 1182.09 MiB
|
||||
load_tensors: HTP1 model buffer size = 6.64 MiB
|
||||
load_tensors: HTP1-REPACK model buffer size = 2505.94 MiB
|
||||
load_tensors: HTP3 model buffer size = 5.55 MiB
|
||||
load_tensors: HTP3-REPACK model buffer size = 2088.28 MiB
|
||||
load_tensors: HTP0 model buffer size = 7.75 MiB
|
||||
load_tensors: HTP0-REPACK model buffer size = 2923.59 MiB
|
||||
load_tensors: HTP2 model buffer size = 6.64 MiB
|
||||
load_tensors: HTP2-REPACK model buffer size = 2505.94 MiB
|
||||
...
|
||||
llama_context: n_ctx_per_seq (8192) < n_ctx_train (131072) -- the full capacity of the model will not be utilized
|
||||
llama_context: CPU output buffer size = 0.77 MiB
|
||||
llama_kv_cache_iswa: creating non-SWA KV cache, size = 8192 cells
|
||||
llama_kv_cache: HTP1 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP3 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: HTP2 KV buffer size = 25.50 MiB
|
||||
llama_kv_cache: size = 102.00 MiB ( 8192 cells, 12 layers, 1/1 seqs), K (q8_0): 51.00 MiB, V (q8_0): 51.00 MiB
|
||||
llama_kv_cache_iswa: creating SWA KV cache, size = 256 cells
|
||||
llama_kv_cache: HTP1 KV buffer size = 0.80 MiB
|
||||
llama_kv_cache: HTP3 KV buffer size = 0.53 MiB
|
||||
llama_kv_cache: HTP0 KV buffer size = 1.06 MiB
|
||||
llama_kv_cache: HTP2 KV buffer size = 0.80 MiB
|
||||
llama_kv_cache: size = 3.19 MiB ( 256 cells, 12 layers, 1/1 seqs), K (q8_0): 1.59 MiB, V (q8_0): 1.59 MiB
|
||||
llama_context: HTP0 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP1 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP2 compute buffer size = 16.06 MiB
|
||||
llama_context: HTP3 compute buffer size = 16.06 MiB
|
||||
llama_context: CPU compute buffer size = 98.19 MiB
|
||||
...
|
||||
llama_perf_context_print: prompt eval time = 3843.67 ms / 197 tokens ( 19.51 ms per token, 51.25 tokens per second)
|
||||
llama_perf_context_print: eval time = 1686.13 ms / 31 runs ( 54.39 ms per token, 18.39 tokens per second)
|
||||
llama_perf_context_print: total time = 6266.30 ms / 228 tokens
|
||||
llama_perf_context_print: graphs reused = 30
|
||||
llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted |
|
||||
llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP2 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - HTP3 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 |
|
||||
llama_memory_breakdown_print: | - Host | 1476 = 1208 + 105 + 162 |
|
||||
llama_memory_breakdown_print: | - HTP1-REPACK | 2505 = 2505 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP3-REPACK | 2088 = 2088 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP0-REPACK | 2923 = 2923 + 0 + 0 |
|
||||
llama_memory_breakdown_print: | - HTP2-REPACK | 2505 = 2505 + 0 + 0 |
|
||||
```
|
||||
14
docs/ops.md
14
docs/ops.md
|
|
@ -22,7 +22,7 @@ Legend:
|
|||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
|
|
@ -42,7 +42,7 @@ Legend:
|
|||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| FLOOR | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
|
|
@ -72,7 +72,7 @@ Legend:
|
|||
| OPT_STEP_SGD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||
| PAD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||
|
|
@ -84,7 +84,7 @@ Legend:
|
|||
| ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| ROUND | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
|
|
@ -100,8 +100,8 @@ Legend:
|
|||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
|
||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
||||
| SUM | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
||||
|
|
@ -111,6 +111,6 @@ Legend:
|
|||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||
| TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| TRUNC | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||
| XIELU | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
|
|
|||
|
|
@ -31,6 +31,14 @@
|
|||
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","XIELU","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
|
||||
"SYCL0","XIELU","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
|
||||
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","TRUNC","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
|
|
@ -95,6 +103,14 @@
|
|||
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","XIELU","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","SYCL"
|
||||
"SYCL0","XIELU","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","SYCL"
|
||||
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","TRUNC","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
|
||||
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
|
||||
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
|
||||
|
|
@ -9363,8 +9379,8 @@
|
|||
"SYCL0","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","SYCL"
|
||||
"SYCL0","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","SYCL"
|
||||
"SYCL0","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","1","yes","SYCL"
|
||||
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","SYCL"
|
||||
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","SYCL"
|
||||
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
|
||||
"SYCL0","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","yes","SYCL"
|
||||
"SYCL0","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","SYCL"
|
||||
"SYCL0","ARANGE","type=f32,start=0.000000,stop=10.000000,step=1.000000","support","0","no","SYCL"
|
||||
"SYCL0","TIMESTEP_EMBEDDING","type=f32,ne_a=[2,1,1,1],dim=320,max_period=10000","support","1","yes","SYCL"
|
||||
|
|
|
|||
|
Can't render this file because it is too large.
|
|
|
@ -3263,27 +3263,27 @@
|
|||
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=0","support","1","yes","Vulkan"
|
||||
"Vulkan0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=1.000000,broadcast=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[4,1536,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=1,n_seqs=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=1","support","1","yes","Vulkan"
|
||||
"Vulkan0","RWKV_WKV6","type=f32,head_count=32,head_size=64,n_seq_tokens=32,n_seqs=4","support","1","yes","Vulkan"
|
||||
|
|
|
|||
|
Can't render this file because it is too large.
|
|
|
@ -251,6 +251,8 @@ option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adr
|
|||
set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"gmml: OpenCL API version to target")
|
||||
|
||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// backend API
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void);
|
||||
|
||||
GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
|
|||
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
|
||||
|
||||
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||
size_t n_threads, size_t n_devices,
|
||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
|
||||
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
|
||||
|
|
|
|||
|
|
@ -307,6 +307,10 @@ function(ggml_add_cpu_backend_variant tag_name)
|
|||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_INTERNAL_${feat} ON)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||
|
|
@ -371,6 +375,14 @@ if (GGML_CPU_ALL_VARIANTS)
|
|||
else()
|
||||
message(FATAL_ERROR "Unsupported PowerPC target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
ggml_add_cpu_backend_variant(s390x_z15 Z15 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z16 Z16 VXE)
|
||||
# ggml_add_cpu_backend_variant(s390x_z17 Z17 VXE)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported s390x target OS: ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}")
|
||||
endif()
|
||||
|
|
@ -390,6 +402,7 @@ ggml_add_backend(Vulkan)
|
|||
ggml_add_backend(WebGPU)
|
||||
ggml_add_backend(zDNN)
|
||||
ggml_add_backend(OpenCL)
|
||||
ggml_add_backend(Hexagon)
|
||||
|
||||
foreach (target ggml-base ggml)
|
||||
target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>)
|
||||
|
|
|
|||
|
|
@ -598,6 +598,26 @@ static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated;
|
||||
}
|
||||
|
||||
// free the extra space at the end if the new tensor is smaller
|
||||
static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_tensor * node, struct ggml_tensor * parent) {
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
||||
|
||||
size_t parent_size = ggml_backend_buft_get_alloc_size(galloc->bufts[p_hn->buffer_id], parent);
|
||||
size_t node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[hn->buffer_id], node);
|
||||
|
||||
GGML_ASSERT(parent_size >= node_size);
|
||||
|
||||
if (parent_size > node_size) {
|
||||
struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id];
|
||||
struct buffer_address p_addr = p_hn->addr;
|
||||
p_addr.offset += node_size;
|
||||
size_t extra_size = parent_size - node_size;
|
||||
AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name);
|
||||
ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) {
|
||||
GGML_ASSERT(buffer_id >= 0);
|
||||
struct hash_node * hn = ggml_gallocr_hash_get(galloc, node);
|
||||
|
|
@ -643,6 +663,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
hn->addr = p_hn->addr;
|
||||
p_hn->allocated = false; // avoid freeing the parent
|
||||
view_src_hn->allocated = false;
|
||||
ggml_gallocr_free_extra_space(galloc, node, view_src);
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -650,6 +671,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|||
hn->buffer_id = p_hn->buffer_id;
|
||||
hn->addr = p_hn->addr;
|
||||
p_hn->allocated = false; // avoid freeing the parent
|
||||
ggml_gallocr_free_extra_space(galloc, node, parent);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -57,6 +57,10 @@
|
|||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_HEXAGON
|
||||
#include "ggml-hexagon.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_BLAS
|
||||
#include "ggml-blas.h"
|
||||
#endif
|
||||
|
|
@ -199,6 +203,9 @@ struct ggml_backend_registry {
|
|||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_HEXAGON
|
||||
register_backend(ggml_backend_hexagon_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_CANN
|
||||
register_backend(ggml_backend_cann_reg());
|
||||
#endif
|
||||
|
|
@ -598,6 +605,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
ggml_backend_load_best("vulkan", silent, dir_path);
|
||||
ggml_backend_load_best("opencl", silent, dir_path);
|
||||
ggml_backend_load_best("hexagon", silent, dir_path);
|
||||
ggml_backend_load_best("musa", silent, dir_path);
|
||||
ggml_backend_load_best("cpu", silent, dir_path);
|
||||
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
|
||||
|
|
|
|||
|
|
@ -51,8 +51,11 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
|
|||
return ACL_DT_UNDEFINED;
|
||||
}
|
||||
|
||||
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
|
||||
size_t* nb, int64_t dims, aclFormat format,
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne,
|
||||
size_t * nb,
|
||||
int64_t dims,
|
||||
aclFormat format,
|
||||
size_t offset) {
|
||||
// If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
|
||||
// added.
|
||||
|
|
@ -84,15 +87,13 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
|
|||
std::reverse(acl_ne, acl_ne + final_dims);
|
||||
std::reverse(acl_stride, acl_stride + final_dims);
|
||||
|
||||
aclTensor* acl_tensor = aclCreateTensor(
|
||||
acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1,
|
||||
tensor->data);
|
||||
aclTensor * acl_tensor = aclCreateTensor(acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
|
||||
elem_offset, format, &acl_storage_len, 1, tensor->data);
|
||||
|
||||
return acl_tensor;
|
||||
}
|
||||
|
||||
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1) {
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
|
||||
return true;
|
||||
|
|
@ -101,11 +102,12 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
|
|||
return false;
|
||||
}
|
||||
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
||||
const ggml_tensor* src1,
|
||||
int64_t* bcast_src0_ne,
|
||||
int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
|
||||
size_t* bcast_src1_nb) {
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
int64_t * bcast_src0_ne,
|
||||
int64_t * bcast_src1_ne,
|
||||
size_t * bcast_src0_nb,
|
||||
size_t * bcast_src1_nb) {
|
||||
GGML_ASSERT(ggml_can_repeat(src1, src0));
|
||||
int bcast_dim_cnt = 0;
|
||||
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
||||
|
|
@ -119,21 +121,26 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
|
|||
// Need to add an extra dim.
|
||||
bcast_src0_ne[bcast_dim_cnt] = nr;
|
||||
bcast_src1_ne[bcast_dim_cnt] = 1;
|
||||
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
|
||||
bcast_src0_ne[bcast_dim_cnt - 1];
|
||||
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
|
||||
bcast_src1_ne[bcast_dim_cnt - 1];
|
||||
bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] * bcast_src0_ne[bcast_dim_cnt - 1];
|
||||
bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] * bcast_src1_ne[bcast_dim_cnt - 1];
|
||||
bcast_dim_cnt++;
|
||||
}
|
||||
}
|
||||
return bcast_dim_cnt;
|
||||
}
|
||||
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(
|
||||
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
|
||||
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
|
||||
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
|
||||
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
|
||||
const int64_t * weight_ne,
|
||||
const int64_t * dst_ne,
|
||||
const size_t * input_nb,
|
||||
const size_t * weight_nb,
|
||||
const size_t * dst_nb,
|
||||
int64_t * bcast_input_ne,
|
||||
int64_t * bcast_weight_ne,
|
||||
int64_t * bcast_dst_ne,
|
||||
size_t * bcast_input_nb,
|
||||
size_t * bcast_weight_nb,
|
||||
size_t * bcast_dst_nb) {
|
||||
// input and dst shoule in same shape, except first two dims.
|
||||
GGML_ASSERT(input_ne[2] == dst_ne[2]);
|
||||
GGML_ASSERT(input_ne[3] == dst_ne[3]);
|
||||
|
|
@ -169,13 +176,9 @@ int64_t ggml_cann_get_mulmat_bcast_shape(
|
|||
bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
|
||||
bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
|
||||
bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
|
||||
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
|
||||
bcast_input_ne[bcast_dim_cnt - 1];
|
||||
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
|
||||
bcast_dst_ne[bcast_dim_cnt - 1];
|
||||
bcast_weight_nb[bcast_dim_cnt] =
|
||||
bcast_weight_nb[bcast_dim_cnt - 1] *
|
||||
bcast_weight_ne[bcast_dim_cnt - 1];
|
||||
bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] * bcast_input_ne[bcast_dim_cnt - 1];
|
||||
bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] * bcast_dst_ne[bcast_dim_cnt - 1];
|
||||
bcast_weight_nb[bcast_dim_cnt] = bcast_weight_nb[bcast_dim_cnt - 1] * bcast_weight_ne[bcast_dim_cnt - 1];
|
||||
bcast_dim_cnt++;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,8 +62,10 @@ aclDataType ggml_cann_type_mapping(ggml_type type);
|
|||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
|
||||
size_t* nb = nullptr, int64_t dims = 0,
|
||||
aclTensor * ggml_cann_create_tensor(const ggml_tensor * tensor,
|
||||
int64_t * ne = nullptr,
|
||||
size_t * nb = nullptr,
|
||||
int64_t dims = 0,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0);
|
||||
|
||||
|
|
@ -87,9 +89,12 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = null
|
|||
* @param offset Offset in bytes for the ACL tensor data. Defaults to 0.
|
||||
* @return Pointer to the created ACL tensor.
|
||||
*/
|
||||
template<typename TYPE>
|
||||
aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
||||
TYPE type_size, int64_t* ne, TYPE* nb,
|
||||
template <typename TYPE>
|
||||
aclTensor * ggml_cann_create_tensor(void * data_ptr,
|
||||
aclDataType dtype,
|
||||
TYPE type_size,
|
||||
int64_t * ne,
|
||||
TYPE * nb,
|
||||
int64_t dims,
|
||||
aclFormat format = ACL_FORMAT_ND,
|
||||
size_t offset = 0) {
|
||||
|
|
@ -109,9 +114,8 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
|||
std::reverse(tmp_ne, tmp_ne + dims);
|
||||
std::reverse(tmp_stride, tmp_stride + dims);
|
||||
|
||||
aclTensor* acl_tensor =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
|
||||
format, &acl_storage_len, 1, data_ptr);
|
||||
aclTensor * acl_tensor =
|
||||
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size, format, &acl_storage_len, 1, data_ptr);
|
||||
|
||||
return acl_tensor;
|
||||
}
|
||||
|
|
@ -132,7 +136,7 @@ aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
|
|||
* to 1. If such a dimension is found, broadcasting is required to align t1
|
||||
* with t0 for element-wise operations.
|
||||
*/
|
||||
bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
|
||||
bool ggml_cann_need_bcast(const ggml_tensor * t0, const ggml_tensor * t1);
|
||||
|
||||
/**
|
||||
* @brief Computes broadcast shapes and strides for two ggml_tensors.
|
||||
|
|
@ -187,9 +191,12 @@ bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
|
|||
* dim1 in a inserted dim, should add nb for dim1,
|
||||
* and all other nb moves to next in order.
|
||||
*/
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
|
||||
int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
|
||||
size_t* bcast_nb_src0, size_t* bcast_nb_src1);
|
||||
int64_t ggml_cann_get_bcast_shape(const ggml_tensor * src0,
|
||||
const ggml_tensor * src1,
|
||||
int64_t * bcast_ne_src0,
|
||||
int64_t * bcast_ne_src1,
|
||||
size_t * bcast_nb_src0,
|
||||
size_t * bcast_nb_src1);
|
||||
|
||||
// Bcast macro to avoid duplicate code.
|
||||
#define BCAST_SHAPE(src0, src1) \
|
||||
|
|
@ -197,9 +204,8 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr
|
|||
int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_dims = ggml_cann_get_bcast_shape( \
|
||||
src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
|
||||
bcast_##src1##_nb);
|
||||
int64_t bcast_dims = ggml_cann_get_bcast_shape(src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, \
|
||||
bcast_##src0##_nb, bcast_##src1##_nb);
|
||||
|
||||
#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
|
||||
|
||||
|
|
@ -233,11 +239,18 @@ int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* sr
|
|||
* before cast dim.
|
||||
* @sa ggml_cann_get_bcast_shape
|
||||
*/
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(
|
||||
const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
|
||||
const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
|
||||
int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
|
||||
size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
|
||||
int64_t ggml_cann_get_mulmat_bcast_shape(const int64_t * input_ne,
|
||||
const int64_t * weight_ne,
|
||||
const int64_t * dst_ne,
|
||||
const size_t * input_nb,
|
||||
const size_t * weight_nb,
|
||||
const size_t * dst_nb,
|
||||
int64_t * bcast_input_ne,
|
||||
int64_t * bcast_weight_ne,
|
||||
int64_t * bcast_dst_ne,
|
||||
size_t * bcast_input_nb,
|
||||
size_t * bcast_weight_nb,
|
||||
size_t * bcast_dst_nb);
|
||||
|
||||
// Bcast macro to avoid duplicate code.
|
||||
#define BCAST_MUL_MAT_SHAPE(input, weight, dst) \
|
||||
|
|
@ -248,11 +261,9 @@ int64_t ggml_cann_get_mulmat_bcast_shape(
|
|||
size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2]; \
|
||||
size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2]; \
|
||||
int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape( \
|
||||
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
|
||||
bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne, \
|
||||
bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
|
||||
input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, bcast_##input##_ne, bcast_##weight##_ne, \
|
||||
bcast_##dst##_ne, bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
|
||||
|
||||
#define BCAST_MUL_MAT_PARAM(tensor) \
|
||||
bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
|
||||
#define BCAST_MUL_MAT_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
|
||||
|
||||
#endif // CANN_ACL_TENSOR_H
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -62,7 +62,7 @@
|
|||
* @param dst The ggml tensor representing the destination, which op is
|
||||
* GGML_OP_REPEAT and specifies the desired dimensions.
|
||||
*/
|
||||
void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the Leaky ReLU activation function to a tensor using the CANN
|
||||
|
|
@ -82,7 +82,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result of the Leaky ReLU
|
||||
* activation is stored, which op is `GGML_OP_LEAKY_RELU`
|
||||
*/
|
||||
void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_leaky_relu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Concatenates multiple tensors along a specified dimension using the
|
||||
|
|
@ -97,7 +97,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @attention tensorList length should be 2 and the dimension using for concat
|
||||
* default to 1.
|
||||
*/
|
||||
void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_concat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Generates a sequence of evenly spaced values within a specified
|
||||
|
|
@ -113,7 +113,7 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* `start`, 'stop' and 'step' are in dst->op_params and dst->op is
|
||||
* `GGML_OP_ARANGE`.
|
||||
*/
|
||||
void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_arange(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a clamp operation to the elements of a ggml tensor using the
|
||||
|
|
@ -131,7 +131,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the clamped values will be stored.
|
||||
* dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params.
|
||||
*/
|
||||
void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_clamp(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Scales the elements of a ggml tensor by a constant factor using the
|
||||
|
|
@ -148,7 +148,7 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the scaled values will be stored.
|
||||
* dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
|
||||
*/
|
||||
void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_scale(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Sorts the elements of a ggml tensor and returns the indices that
|
||||
|
|
@ -163,7 +163,7 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the sorted indices will be stored.
|
||||
* dst->op is `GGML_OP_ARGSORT`.
|
||||
*/
|
||||
void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_argsort(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Layer Normalization for a ggml tensor using the CANN
|
||||
|
|
@ -185,7 +185,7 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the normalized values will be stored.
|
||||
* @attention `Var` defaults to dst->ne[0].
|
||||
*/
|
||||
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Group Normalization for a ggml tensor using the CANN
|
||||
|
|
@ -209,7 +209,7 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @attention eps defaults to 1e-6f.
|
||||
*/
|
||||
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the accumulation of tensors using the CANN backend.
|
||||
|
|
@ -228,7 +228,7 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the accumulated values will be stored.
|
||||
* `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`.
|
||||
*/
|
||||
void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the sum of elements along the last dimension of a ggml tensor
|
||||
|
|
@ -244,7 +244,7 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @attention `reduce_dims` defaults to 3, which means the last dimension.
|
||||
*/
|
||||
void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the sum of elements in a ggml tensor.
|
||||
|
|
@ -258,7 +258,7 @@ void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
*/
|
||||
|
||||
void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
|
||||
|
|
@ -274,8 +274,7 @@ void ggml_cann_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the upsampled values will be stored.
|
||||
* dst->op is `GGML_OP_UPSCALE`.
|
||||
*/
|
||||
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
|
||||
ggml_tensor* dst);
|
||||
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Pads a ggml tensor to match the dimensions of the destination tensor
|
||||
|
|
@ -290,7 +289,7 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
|
|||
* @param dst The destination tensor, which specifies the target dimensions for
|
||||
* padding. dst->op is `GGML_OP_PAD`.
|
||||
*/
|
||||
void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pad(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Executes a 2D pooling operation on a ggml tensor using the CANN
|
||||
|
|
@ -307,7 +306,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor on which the pooling operation is to be
|
||||
* performed. dst->op is `GGML_OP_POOL_2D`.
|
||||
*/
|
||||
void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pool2d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Duplicates a ggml tensor using the CANN backend.
|
||||
|
|
@ -326,7 +325,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* different shape and dst is no-contiguous.
|
||||
* @note: This func need to simplify.
|
||||
*/
|
||||
void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_dup(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the Root Mean Square (RMS) normalization of a ggml tensor
|
||||
|
|
@ -348,7 +347,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the normalized values will be stored.
|
||||
* dst->op is `GGML_OP_RMS_NORM`.
|
||||
*/
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_rms_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a diagonal mask to the tensor with a specified value.
|
||||
|
|
@ -363,7 +362,7 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* `GGML_OP_DIAG_MASK`
|
||||
* @param value The value to use for masking.
|
||||
*/
|
||||
void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
|
||||
void ggml_cann_diag_mask(ggml_backend_cann_context & ctx, ggml_tensor * dst, float value);
|
||||
|
||||
/**
|
||||
* @brief Performs an image-to-column transformation on the input tensor.
|
||||
|
|
@ -378,7 +377,7 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float
|
|||
* @param dst The destination tensor that stores the result of the operation.
|
||||
* dst->op is `GGML_OP_IM2COL`.
|
||||
*/
|
||||
void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_im2col(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes time step embeddings using sine and cosine functions.
|
||||
|
|
@ -392,10 +391,10 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result of the embedding operation
|
||||
* will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
|
||||
*/
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// @see ggml_cann_dup.
|
||||
void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the softmax activation with optional masking.
|
||||
|
|
@ -417,7 +416,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored. dst->op is
|
||||
* `GGML_OP_SOFTMAX`.
|
||||
*/
|
||||
void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Extracts specific rows from a tensor based on indices.
|
||||
|
|
@ -429,7 +428,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param ctx The backend CANN context for executing operations.
|
||||
* @param dst The destination tensor where the extracted rows will be stored.
|
||||
*/
|
||||
void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Writes specific rows into a tensor at positions specified by indices.
|
||||
|
|
@ -441,7 +440,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param ctx The backend CANN context for executing operations.
|
||||
* @param dst The destination tensor where the specified rows will be updated.
|
||||
*/
|
||||
void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Executes matrix multiplication for the given tensor.
|
||||
|
|
@ -454,7 +453,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor for storing the result of the matrix
|
||||
* multiplication. dst->op is `GGML_OP_MUL_MAT`.
|
||||
*/
|
||||
void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
|
||||
|
|
@ -477,7 +476,7 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @note The function currently does not support cases where the freq_scale is
|
||||
* not equal 1.
|
||||
*/
|
||||
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the index of the maximum value along the specified dimension
|
||||
|
|
@ -492,7 +491,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the indices of the maximum values will
|
||||
* be stored. dst->op is `GGML_OP_ARGMAX`.
|
||||
*/
|
||||
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Adds two tensors element-wise and stores the result in a destination
|
||||
|
|
@ -509,8 +508,10 @@ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param acl_src1 The second source tensor.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
||||
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_add(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src0,
|
||||
aclTensor * acl_src1,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Sub two tensors element-wise and stores the result in a destination
|
||||
|
|
@ -527,8 +528,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
|||
* @param acl_src1 The second source tensor.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
||||
aclTensor* acl_src1, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_sub(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src0,
|
||||
aclTensor * acl_src1,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Performs element-wise multiplication of two tensors and stores the
|
||||
|
|
@ -546,8 +549,10 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
|
|||
* @param acl_other The second tensor for element-wise multiplication.
|
||||
* @param acl_dst The destination tensor where the result will be stored.
|
||||
*/
|
||||
void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_mul(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src,
|
||||
aclTensor * acl_other,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Matrix division, optionally in-place.
|
||||
|
|
@ -567,8 +572,10 @@ void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param inplace Flag indicating whether to perform the operation in-place on
|
||||
* `acl_src`.
|
||||
*/
|
||||
void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_other, aclTensor* acl_dst = nullptr);
|
||||
void aclnn_div(ggml_backend_cann_context & ctx,
|
||||
aclTensor * acl_src,
|
||||
aclTensor * acl_other,
|
||||
aclTensor * acl_dst = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Applies element-wise cosine function to the elements of a tensor.
|
||||
|
|
@ -584,8 +591,7 @@ void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_dst The destination tensor where the cosine results will be
|
||||
* stored.
|
||||
*/
|
||||
void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst);
|
||||
void aclnn_cos(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Applies element-wise sine function to the elements of a tensor.
|
||||
|
|
@ -602,8 +608,7 @@ void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_src The source tensor on which the sine function will be applied.
|
||||
* @param acl_dst The destination tensor where the sine results will be stored.
|
||||
*/
|
||||
void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
||||
aclTensor* acl_dst);
|
||||
void aclnn_sin(ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Prepares broadcast-compatible ACL tensors for two input tensors and one
|
||||
|
|
@ -621,8 +626,12 @@ void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
|
|||
* @param acl_src1 Output pointer to the created ACL tensor corresponding to src1.
|
||||
* @param acl_dst Output pointer to the created ACL tensor corresponding to dst.
|
||||
*/
|
||||
void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
|
||||
aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst);
|
||||
void bcast_shape(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
aclTensor ** acl_src0,
|
||||
aclTensor ** acl_src1,
|
||||
aclTensor ** acl_dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the 1D transposed convolution (deconvolution) of a ggml
|
||||
|
|
@ -637,7 +646,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst,
|
|||
* @param dst The destination tensor where the transposed convolution result
|
||||
* will be stored. dst->op is `GGML_OP_CONV_TRANSPOSE_1D`.
|
||||
*/
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the ELU (Exponential Linear Unit) activation to a ggml tensor
|
||||
|
|
@ -662,7 +671,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
|
|||
* @param dst The destination tensor where the ELU-activated result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_ELU`.
|
||||
*/
|
||||
void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_elu(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Computes the mean of a ggml tensor element-wise using the CANN backend.
|
||||
|
|
@ -677,7 +686,7 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the mean result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_MEAN`.
|
||||
*/
|
||||
void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mean(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies 1D reflect padding to a ggml tensor using the CANN backend.
|
||||
|
|
@ -692,7 +701,7 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the padded result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_PAD_REFLECT_1D`.
|
||||
*/
|
||||
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Counts the number of equal elements in two ggml tensors using the CANN backend.
|
||||
|
|
@ -708,7 +717,7 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_COUNT_EQUAL`.
|
||||
*/
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies the Step activation function to a ggml tensor using the CANN backend.
|
||||
|
|
@ -723,7 +732,7 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_STEP`.
|
||||
*/
|
||||
void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Performs the Flash Attention extended operator using the CANN backend.
|
||||
|
|
@ -738,59 +747,46 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
* @param dst The destination tensor where the result will be stored.
|
||||
* dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
|
||||
*/
|
||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/*
|
||||
* @brief A generic wrapper for ACL resources with custom deleter support.
|
||||
*/
|
||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
|
||||
using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
|
||||
|
||||
/**
|
||||
* @brief Trait structure used to define how to destroy a given ACL resource type.
|
||||
*
|
||||
* @tparam T ACL resource type.
|
||||
*/
|
||||
template<typename T>
|
||||
struct acl_resource_traits;
|
||||
template <typename T> struct acl_resource_traits;
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensor> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclTensor> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclIntArray> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclIntArray> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclScalar> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclScalar> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
|
||||
*/
|
||||
template<>
|
||||
struct acl_resource_traits<aclTensorList> {
|
||||
static void destroy(void* p) {
|
||||
ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
|
||||
}
|
||||
template <> struct acl_resource_traits<aclTensorList> {
|
||||
static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -800,14 +796,8 @@ struct acl_resource_traits<aclTensorList> {
|
|||
* @param ptr Raw pointer to ACL resource.
|
||||
* @return any_acl_resource Smart pointer that handles destruction.
|
||||
*/
|
||||
template<typename T>
|
||||
any_acl_resource make_acl_resource(T* ptr) {
|
||||
return any_acl_resource(
|
||||
static_cast<void*>(ptr),
|
||||
[](void* p) {
|
||||
acl_resource_traits<T>::destroy(p);
|
||||
}
|
||||
);
|
||||
template <typename T> any_acl_resource make_acl_resource(T * ptr) {
|
||||
return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -817,8 +807,7 @@ any_acl_resource make_acl_resource(T* ptr) {
|
|||
* @param vec Target vector to hold ACL resources.
|
||||
* @param args Raw pointers to ACL resources.
|
||||
*/
|
||||
template<typename... Args>
|
||||
void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
|
||||
template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
|
||||
(vec.emplace_back(make_acl_resource(args)), ...);
|
||||
}
|
||||
|
||||
|
|
@ -827,17 +816,18 @@ void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
|
|||
*/
|
||||
class aclnn_task : public cann_task {
|
||||
public:
|
||||
aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
|
||||
uint64_t workspace_size, aclOpExecutor * executor,
|
||||
aclnn_task(aclnn_func_t aclnn_func,
|
||||
void * workspace_addr,
|
||||
uint64_t workspace_size,
|
||||
aclOpExecutor * executor,
|
||||
aclrtStream stream) :
|
||||
aclnn_func_(aclnn_func),
|
||||
workspace_addr_(workspace_addr),
|
||||
workspace_size_(workspace_size),
|
||||
executor_(executor),
|
||||
stream_(stream) {}
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
|
||||
}
|
||||
|
||||
virtual void run_task() override { ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); }
|
||||
private:
|
||||
aclnn_func_t aclnn_func_;
|
||||
void * workspace_addr_;
|
||||
|
|
@ -850,15 +840,11 @@ class aclnn_task : public cann_task {
|
|||
* @brief Task class that releases ACL resources after usage.
|
||||
*/
|
||||
class release_resource_task : public cann_task {
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource>&& resources){
|
||||
resource_ = std::move(resources);
|
||||
}
|
||||
public:
|
||||
release_resource_task(std::vector<any_acl_resource> && resources) { resource_ = std::move(resources); }
|
||||
|
||||
virtual void run_task() override {
|
||||
resource_.clear();
|
||||
}
|
||||
private:
|
||||
virtual void run_task() override { resource_.clear(); }
|
||||
private:
|
||||
std::vector<any_acl_resource> resource_;
|
||||
};
|
||||
|
||||
|
|
@ -866,17 +852,18 @@ private:
|
|||
* @brief Task class for performing asynchronous memory copy operations.
|
||||
*/
|
||||
class async_memcpy_task : public cann_task {
|
||||
public:
|
||||
async_memcpy_task(void* dst, const void* src, size_t size,
|
||||
aclrtMemcpyKind kind, aclrtStream stream)
|
||||
: dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
|
||||
public:
|
||||
async_memcpy_task(void * dst, const void * src, size_t size, aclrtMemcpyKind kind, aclrtStream stream) :
|
||||
dst_(dst),
|
||||
src_(src),
|
||||
size_(size),
|
||||
kind_(kind),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
|
||||
}
|
||||
private:
|
||||
void* dst_;
|
||||
const void* src_;
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); }
|
||||
private:
|
||||
void * dst_;
|
||||
const void * src_;
|
||||
size_t size_;
|
||||
aclrtMemcpyKind kind_;
|
||||
aclrtStream stream_;
|
||||
|
|
@ -887,14 +874,15 @@ private:
|
|||
*/
|
||||
class async_memset_task : public cann_task {
|
||||
public:
|
||||
async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
|
||||
: buffer_(buffer), size_(size), value_(value), stream_(stream) {}
|
||||
async_memset_task(void * buffer, size_t size, int32_t value, aclrtStream stream) :
|
||||
buffer_(buffer),
|
||||
size_(size),
|
||||
value_(value),
|
||||
stream_(stream) {}
|
||||
|
||||
virtual void run_task() override {
|
||||
ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
|
||||
}
|
||||
virtual void run_task() override { ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); }
|
||||
private:
|
||||
void* buffer_;
|
||||
void * buffer_;
|
||||
size_t size_;
|
||||
int32_t value_;
|
||||
aclrtStream stream_;
|
||||
|
|
@ -923,7 +911,7 @@ class async_memset_task : public cann_task {
|
|||
uint64_t workspaceSize = 0; \
|
||||
aclOpExecutor * executor; \
|
||||
void * workspaceAddr = nullptr; \
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
|
||||
ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
|
||||
/* workspace should alloced in main thread to keep malloc order when using vmm. */ \
|
||||
if (workspaceSize > 0) { \
|
||||
ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
|
||||
|
|
@ -931,11 +919,10 @@ class async_memset_task : public cann_task {
|
|||
} \
|
||||
if (CTX.async_mode) { \
|
||||
auto task = \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, \
|
||||
executor, CTX.stream()); \
|
||||
std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, executor, CTX.stream()); \
|
||||
CTX.task_queue.submit_task(std::move(task)); \
|
||||
} else { \
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
|
||||
ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
|
@ -947,11 +934,10 @@ class async_memset_task : public cann_task {
|
|||
* @param ctx Backend context which manages task submission and async mode.
|
||||
* @param args Pointers to ACL resources to be released.
|
||||
*/
|
||||
template <typename... Args>
|
||||
void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
template <typename... Args> void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
|
||||
std::vector<any_acl_resource> resources;
|
||||
register_acl_resources(resources, std::forward<Args>(args)...);
|
||||
if(ctx.async_mode) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<release_resource_task>(std::move(resources));
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
}
|
||||
|
|
@ -966,8 +952,11 @@ void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... arg
|
|||
* @param len Size of memory to copy (in bytes).
|
||||
* @param kind Type of memory copy (host-to-device, device-to-host, etc).
|
||||
*/
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
|
|
@ -976,8 +965,11 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
|
|||
}
|
||||
}
|
||||
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
|
||||
const void * src, size_t len, aclrtMemcpyKind kind) {
|
||||
inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t len,
|
||||
aclrtMemcpyKind kind) {
|
||||
if (ctx->async_mode) {
|
||||
auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void *>(src), len, kind, ctx->stream());
|
||||
ctx->task_queue.submit_task(std::move(task));
|
||||
|
|
@ -994,8 +986,7 @@ inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
|
|||
* @param size Size of the memory buffer (in bytes).
|
||||
* @param value Value to set in the buffer.
|
||||
*/
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
|
||||
size_t size, int value) {
|
||||
inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, size_t size, int value) {
|
||||
if (ctx.async_mode) {
|
||||
auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
|
||||
ctx.task_queue.submit_task(std::move(task));
|
||||
|
|
@ -1029,7 +1020,7 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe
|
|||
* @param dst The destination tensor where the expert-weighted token outputs are stored.
|
||||
* Expected to be of shape [M, K, N, 1].
|
||||
*/
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Check whether a tensor is a weight tensor for matrix multiplication.
|
||||
|
|
@ -1041,20 +1032,14 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
|||
*
|
||||
* @param tensor Pointer to the target ggml_tensor object (const-qualified).
|
||||
*/
|
||||
static bool is_matmul_weight(const ggml_tensor* tensor) {
|
||||
static bool is_matmul_weight(const ggml_tensor * tensor) {
|
||||
std::string name = ggml_get_name(tensor);
|
||||
static const std::unordered_set<std::string> weight_suffixes{
|
||||
"output.weight",
|
||||
"attn_q.weight",
|
||||
"attn_k.weight",
|
||||
"attn_v.weight",
|
||||
"attn_output.weight",
|
||||
"ffn_gate.weight",
|
||||
"ffn_up.weight",
|
||||
"ffn_down.weight"
|
||||
};
|
||||
static const std::unordered_set<std::string> weight_suffixes{ "output.weight", "attn_q.weight",
|
||||
"attn_k.weight", "attn_v.weight",
|
||||
"attn_output.weight", "ffn_gate.weight",
|
||||
"ffn_up.weight", "ffn_down.weight" };
|
||||
|
||||
for (const auto& suffix : weight_suffixes) {
|
||||
for (const auto & suffix : weight_suffixes) {
|
||||
if (name.find(suffix) != std::string::npos) {
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1078,14 +1063,13 @@ static bool is_matmul_weight(const ggml_tensor* tensor) {
|
|||
* @param ctx The CANN backend context used to manage execution and resources.
|
||||
* @param dst The destination tensor.
|
||||
*/
|
||||
template <auto binary_op>
|
||||
void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src0 = dst->src[0];
|
||||
ggml_tensor* src1 = dst->src[1];
|
||||
template <auto binary_op> void ggml_cann_binary_op(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src0 = dst->src[0];
|
||||
ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
aclTensor* acl_src0;
|
||||
aclTensor* acl_src1;
|
||||
aclTensor* acl_dst;
|
||||
aclTensor * acl_src0;
|
||||
aclTensor * acl_src1;
|
||||
aclTensor * acl_dst;
|
||||
|
||||
// Need bcast
|
||||
bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
|
||||
|
|
@ -1094,7 +1078,6 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to an input tensor using the CANN backend.
|
||||
*
|
||||
|
|
@ -1107,12 +1090,12 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
|||
* @param ctx The CANN backend context for managing resources and execution.
|
||||
* @param dst The destination tensor. Its src[0] is treated as the input tensor.
|
||||
*/
|
||||
template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
||||
ggml_tensor* src = dst->src[0];
|
||||
template <void unary_op(ggml_backend_cann_context &, aclTensor *, aclTensor *)>
|
||||
void ggml_cann_op_unary(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
aclTensor* acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
|
||||
aclTensor * acl_src = ggml_cann_create_tensor(src);
|
||||
aclTensor * acl_dst = ggml_cann_create_tensor(dst);
|
||||
|
||||
unary_op(ctx, acl_src, acl_dst);
|
||||
ggml_cann_release_resources(ctx, acl_src, acl_dst);
|
||||
|
|
@ -1138,9 +1121,9 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
|
|||
*
|
||||
* @see GGML_CANN_CALL_OP_UNARY
|
||||
*/
|
||||
void ggml_cann_op_unary(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_op_unary(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
|
||||
ggml_backend_cann_context & ctx,
|
||||
ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Applies a gated (GLU-style) unary operation using the CANN backend.
|
||||
|
|
@ -1172,9 +1155,9 @@ void ggml_cann_op_unary(
|
|||
*
|
||||
* @see GGML_CANN_CALL_OP_UNARY_GATED
|
||||
*/
|
||||
void ggml_cann_op_unary_gated(
|
||||
std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
|
||||
ggml_backend_cann_context& ctx, ggml_tensor* dst);
|
||||
void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, aclTensor *, aclTensor *)> unary_op,
|
||||
ggml_backend_cann_context & ctx,
|
||||
ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
|
||||
|
|
@ -1199,14 +1182,11 @@ void ggml_cann_op_unary_gated(
|
|||
*/
|
||||
#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_op_unary(lambda, ctx, dst); \
|
||||
} \
|
||||
while (0)
|
||||
} while (0)
|
||||
|
||||
/**
|
||||
* @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
|
||||
|
|
@ -1231,13 +1211,10 @@ void ggml_cann_op_unary_gated(
|
|||
*/
|
||||
#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \
|
||||
do { \
|
||||
auto lambda = [](ggml_backend_cann_context& ctx, \
|
||||
aclTensor* acl_src, \
|
||||
aclTensor* acl_dst) { \
|
||||
auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \
|
||||
}; \
|
||||
ggml_cann_op_unary_gated(lambda, ctx, dst); \
|
||||
} \
|
||||
while (0)
|
||||
} while (0)
|
||||
|
||||
#endif // CANN_ACLNN_OPS
|
||||
|
|
|
|||
|
|
@ -56,8 +56,7 @@
|
|||
* @param line The line number at which the error occurred.
|
||||
* @param msg The error message.
|
||||
*/
|
||||
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
|
||||
const char* file, int line, const char* msg);
|
||||
[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
|
||||
|
||||
/**
|
||||
* @brief Checks the result of a CANN function call and invokes the error
|
||||
|
|
@ -96,18 +95,17 @@ struct ggml_cann_device_info {
|
|||
size_t total_vram; /**< Total video RAM available on the device. */
|
||||
};
|
||||
|
||||
cann_device_info devices[GGML_CANN_MAX_DEVICES] =
|
||||
{}; /**< Array of CANN device information. */
|
||||
cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */
|
||||
};
|
||||
|
||||
const ggml_cann_device_info& ggml_cann_info();
|
||||
const ggml_cann_device_info & ggml_cann_info();
|
||||
|
||||
void ggml_cann_set_device(int32_t device);
|
||||
int32_t ggml_cann_get_device();
|
||||
|
||||
std::optional<std::string> get_env(const std::string& name);
|
||||
bool parse_bool(const std::string& value);
|
||||
int parse_integer(const std::string& value);
|
||||
std::optional<std::string> get_env(const std::string & name);
|
||||
bool parse_bool(const std::string & value);
|
||||
int parse_integer(const std::string & value);
|
||||
|
||||
/**
|
||||
* @brief Abstract base class for memory pools used by CANN.
|
||||
|
|
@ -126,7 +124,7 @@ struct ggml_cann_pool {
|
|||
* will be stored.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
virtual void* alloc(size_t size, size_t* actual_size) = 0;
|
||||
virtual void * alloc(size_t size, size_t * actual_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Frees a previously allocated memory block.
|
||||
|
|
@ -136,15 +134,15 @@ struct ggml_cann_pool {
|
|||
* @note Note that all CANN opertors are running async. Make sure memory is
|
||||
* still avaiable before this operator finished.
|
||||
*/
|
||||
virtual void free(void* ptr, size_t size) = 0;
|
||||
virtual void free(void * ptr, size_t size) = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief RAII wrapper for managing memory allocations from a CANN memory pool.
|
||||
*/
|
||||
struct ggml_cann_pool_alloc {
|
||||
ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
|
||||
void* ptr = nullptr; /**< Pointer to the allocated memory block. */
|
||||
ggml_cann_pool * pool = nullptr; /**< Pointer to the memory pool. */
|
||||
void * ptr = nullptr; /**< Pointer to the allocated memory block. */
|
||||
size_t actual_size = 0; /**< Actual size of the allocated memory block. */
|
||||
|
||||
/**
|
||||
|
|
@ -156,16 +154,14 @@ struct ggml_cann_pool_alloc {
|
|||
* @brief Constructor that initializes the memory pool.
|
||||
* @param pool Reference to the memory pool.
|
||||
*/
|
||||
explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
|
||||
explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}
|
||||
|
||||
/**
|
||||
* @brief Constructor that initializes the memory pool and allocates memory.
|
||||
* @param pool Reference to the memory pool.
|
||||
* @param size Size of the memory block to allocate.
|
||||
*/
|
||||
ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
|
||||
alloc(size);
|
||||
}
|
||||
ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }
|
||||
|
||||
/**
|
||||
* @brief Destructor that frees the allocated memory block.
|
||||
|
|
@ -181,7 +177,7 @@ struct ggml_cann_pool_alloc {
|
|||
* @param size Size of the memory block to allocate.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* alloc(size_t size) {
|
||||
void * alloc(size_t size) {
|
||||
GGML_ASSERT(pool != nullptr);
|
||||
GGML_ASSERT(ptr == nullptr);
|
||||
ptr = pool->alloc(size, &this->actual_size);
|
||||
|
|
@ -194,7 +190,7 @@ struct ggml_cann_pool_alloc {
|
|||
* @param size Size of the memory block to allocate.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* alloc(ggml_cann_pool& pool, size_t size) {
|
||||
void * alloc(ggml_cann_pool & pool, size_t size) {
|
||||
this->pool = &pool;
|
||||
return alloc(size);
|
||||
}
|
||||
|
|
@ -203,25 +199,25 @@ struct ggml_cann_pool_alloc {
|
|||
* @brief Gets the pointer to the allocated memory block.
|
||||
* @return Pointer to the allocated memory block.
|
||||
*/
|
||||
void* get() { return ptr; }
|
||||
void * get() { return ptr; }
|
||||
|
||||
// Deleted copy constructor
|
||||
ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
|
||||
ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;
|
||||
|
||||
// Deleted move constructor
|
||||
ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
|
||||
ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;
|
||||
|
||||
// Deleted copy assignment operator
|
||||
ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
|
||||
ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete;
|
||||
|
||||
// Deleted move assignment operator
|
||||
ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
|
||||
ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Function pointer type for ACLNN operator calls.
|
||||
*/
|
||||
using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
|
||||
using aclnn_func_t = aclnnStatus (*)(void *, uint64_t, aclOpExecutor *, aclrtStream);
|
||||
|
||||
/**
|
||||
* @brief Base class for all CANN tasks to be submitted to the task queue.
|
||||
|
|
@ -229,7 +225,7 @@ using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStrea
|
|||
* Users should override the run_task() method with actual task logic.
|
||||
*/
|
||||
class cann_task {
|
||||
public:
|
||||
public:
|
||||
virtual void run_task() {}
|
||||
};
|
||||
|
||||
|
|
@ -237,16 +233,20 @@ public:
|
|||
* @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances.
|
||||
*/
|
||||
class cann_task_queue {
|
||||
public:
|
||||
public:
|
||||
/**
|
||||
* @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
|
||||
*
|
||||
* @param capacity Queue capacity. Must be a power of 2.
|
||||
* @param device Target device ID (used for context setting).
|
||||
*/
|
||||
explicit cann_task_queue(size_t capacity, int32_t device)
|
||||
: buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
|
||||
running_(false), device_(device) {
|
||||
explicit cann_task_queue(size_t capacity, int32_t device) :
|
||||
buffer_(capacity),
|
||||
capacity_(capacity),
|
||||
head_(0),
|
||||
tail_(0),
|
||||
running_(false),
|
||||
device_(device) {
|
||||
GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
|
||||
mask_ = capacity_ - 1;
|
||||
}
|
||||
|
|
@ -257,7 +257,7 @@ public:
|
|||
* @param item Unique pointer to the task.
|
||||
* @return true if the task was successfully enqueued, false if the queue was full.
|
||||
*/
|
||||
bool enqueue(std::unique_ptr<cann_task>&& item) {
|
||||
bool enqueue(std::unique_ptr<cann_task> && item) {
|
||||
size_t next_tail = (tail_ + 1) & mask_;
|
||||
|
||||
if (next_tail == head_) {
|
||||
|
|
@ -276,8 +276,8 @@ public:
|
|||
*
|
||||
* @param task Task to be submitted.
|
||||
*/
|
||||
void submit_task(std::unique_ptr<cann_task>&& task) {
|
||||
while(!enqueue(std::move(task))) {
|
||||
void submit_task(std::unique_ptr<cann_task> && task) {
|
||||
while (!enqueue(std::move(task))) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
|
@ -286,7 +286,6 @@ public:
|
|||
running_ = true;
|
||||
thread_ = std::thread(&cann_task_queue::execute, this);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -309,7 +308,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
private:
|
||||
private:
|
||||
/**
|
||||
* @brief Worker thread function that continuously dequeues and executes tasks.
|
||||
*/
|
||||
|
|
@ -317,7 +316,7 @@ private:
|
|||
ggml_cann_set_device(device_);
|
||||
|
||||
while (running_) {
|
||||
if(head_ == tail_) {
|
||||
if (head_ == tail_) {
|
||||
std::this_thread::yield();
|
||||
continue;
|
||||
}
|
||||
|
|
@ -378,11 +377,9 @@ struct ggml_cann_graph {
|
|||
struct ggml_cann_graph_lru_cache {
|
||||
size_t capacity; /**< Maximum number of graphs in the cache. */
|
||||
|
||||
std::list<ggml_cann_graph*> cache_list; /**< List storing cached graphs as raw pointers. */
|
||||
std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */
|
||||
|
||||
ggml_cann_graph_lru_cache() {
|
||||
capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12"));
|
||||
}
|
||||
ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
|
||||
|
||||
/**
|
||||
* @brief Push a new graph to the front of the cache.
|
||||
|
|
@ -390,9 +387,9 @@ struct ggml_cann_graph_lru_cache {
|
|||
* @param new_node Pointer to the new ggml_cann_graph to cache.
|
||||
* Ownership is transferred to the cache (cache will delete it).
|
||||
*/
|
||||
void push(ggml_cann_graph* new_node) {
|
||||
void push(ggml_cann_graph * new_node) {
|
||||
if (cache_list.size() >= capacity) {
|
||||
ggml_cann_graph* old = cache_list.back();
|
||||
ggml_cann_graph * old = cache_list.back();
|
||||
cache_list.pop_back();
|
||||
delete old; // free the old graph
|
||||
}
|
||||
|
|
@ -403,7 +400,7 @@ struct ggml_cann_graph_lru_cache {
|
|||
* @brief Move an existing graph to the front of the cache.
|
||||
* @param node Pointer to the ggml_cann_graph to move.
|
||||
*/
|
||||
void move_to_front(ggml_cann_graph* node) {
|
||||
void move_to_front(ggml_cann_graph * node) {
|
||||
cache_list.remove(node);
|
||||
cache_list.push_front(node);
|
||||
}
|
||||
|
|
@ -421,30 +418,28 @@ struct ggml_cann_graph_lru_cache {
|
|||
/**
|
||||
* @brief Destructor that clears the cache and frees all cached graphs.
|
||||
*/
|
||||
~ggml_cann_graph_lru_cache() {
|
||||
clear();
|
||||
}
|
||||
~ggml_cann_graph_lru_cache() { clear(); }
|
||||
};
|
||||
#endif // USE_ACL_GRAPH
|
||||
|
||||
struct ggml_cann_rope_cache {
|
||||
~ggml_cann_rope_cache() {
|
||||
if(theta_scale_cache != nullptr) {
|
||||
if (theta_scale_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(theta_scale_cache));
|
||||
}
|
||||
if(sin_cache != nullptr) {
|
||||
if (sin_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(sin_cache));
|
||||
}
|
||||
if(cos_cache != nullptr) {
|
||||
if (cos_cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cos_cache));
|
||||
}
|
||||
}
|
||||
|
||||
void* theta_scale_cache = nullptr;
|
||||
void * theta_scale_cache = nullptr;
|
||||
int64_t theta_scale_length = 0;
|
||||
// sin/cos cache, used only to accelerate first layer on each device
|
||||
void* sin_cache = nullptr;
|
||||
void* cos_cache = nullptr;
|
||||
void * sin_cache = nullptr;
|
||||
void * cos_cache = nullptr;
|
||||
int64_t position_length = 0;
|
||||
// Properties to check before reusing the sincos cache
|
||||
bool cached = false;
|
||||
|
|
@ -457,12 +452,12 @@ struct ggml_cann_rope_cache {
|
|||
|
||||
struct ggml_cann_tensor_cache {
|
||||
~ggml_cann_tensor_cache() {
|
||||
if(cache != nullptr) {
|
||||
if (cache != nullptr) {
|
||||
ACL_CHECK(aclrtFree(cache));
|
||||
}
|
||||
}
|
||||
|
||||
void* cache = nullptr;
|
||||
void * cache = nullptr;
|
||||
int64_t size = 0;
|
||||
};
|
||||
|
||||
|
|
@ -487,25 +482,24 @@ struct ggml_backend_cann_context {
|
|||
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
|
||||
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
|
||||
|
||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
|
||||
aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
|
||||
|
||||
/**
|
||||
* @brief Constructor for initializing the context with a given device.
|
||||
* @param device Device ID.
|
||||
*/
|
||||
explicit ggml_backend_cann_context(int device)
|
||||
: device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
|
||||
explicit ggml_backend_cann_context(int device) :
|
||||
device(device),
|
||||
name("CANN" + std::to_string(device)),
|
||||
task_queue(1024, device) {
|
||||
ggml_cann_set_device(device);
|
||||
description = aclrtGetSocName();
|
||||
|
||||
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
|
||||
device, async_mode ? "ON" : "OFF");
|
||||
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, device, async_mode ? "ON" : "OFF");
|
||||
#ifdef USE_ACL_GRAPH
|
||||
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
|
||||
__func__, device,
|
||||
acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
|
||||
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
|
||||
#endif
|
||||
}
|
||||
|
|
@ -549,8 +543,7 @@ struct ggml_backend_cann_context {
|
|||
aclrtStream stream() { return stream(0); }
|
||||
|
||||
// TODO: each stream should have a memory pool.
|
||||
std::unique_ptr<ggml_cann_pool>
|
||||
mem_pool; /**< Memory pool for the device. */
|
||||
std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */
|
||||
|
||||
/**
|
||||
* @brief Create a new memory pool for a given device.
|
||||
|
|
@ -563,7 +556,7 @@ struct ggml_backend_cann_context {
|
|||
* @brief Get or create the memory pool for the context.
|
||||
* @return Reference to the memory pool.
|
||||
*/
|
||||
ggml_cann_pool& pool() {
|
||||
ggml_cann_pool & pool() {
|
||||
if (mem_pool == nullptr) {
|
||||
mem_pool = new_pool_for_device(device);
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -466,7 +466,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
|
||||
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
|
||||
message(STATUS "s390x detected")
|
||||
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/arch/s390/quants.c)
|
||||
|
||||
# for native compilation
|
||||
if (GGML_NATIVE)
|
||||
# check machine level to determine target
|
||||
file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
|
||||
string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
|
||||
|
||||
|
|
@ -487,8 +492,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
||||
list(APPEND ARCH_FLAGS -march=native -mtune=native)
|
||||
endif()
|
||||
# for cross-compilation
|
||||
elseif(GGML_CPU_ALL_VARIANTS)
|
||||
# range through IBM z15 to z17
|
||||
# NOTE: update when a new hardware level is released
|
||||
foreach (ZHW RANGE 15 17)
|
||||
if(DEFINED GGML_INTERNAL_Z${ZHW})
|
||||
message(STATUS "z${ZHW} cross-compile target")
|
||||
list(APPEND ARCH_FLAGS -march=z${ZHW})
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if (GGML_VXE)
|
||||
if (GGML_VXE OR GGML_INTERNAL_VXE)
|
||||
message(STATUS "VX/VXE/VXE2 enabled")
|
||||
list(APPEND ARCH_FLAGS -mvx -mzvector)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_VXE)
|
||||
|
|
|
|||
|
|
@ -3567,13 +3567,17 @@ void ggml_cpu_init(void) {
|
|||
#ifdef GGML_USE_OPENMP
|
||||
//if (!getenv("OMP_WAIT_POLICY")) {
|
||||
// // set the wait policy to active, so that OpenMP threads don't sleep
|
||||
// putenv("OMP_WAIT_POLICY=active");
|
||||
// setenv("OMP_WAIT_POLICY", "active", 0)
|
||||
//}
|
||||
|
||||
if (!getenv("KMP_BLOCKTIME")) {
|
||||
// set the time to wait before sleeping a thread
|
||||
// this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases
|
||||
putenv("KMP_BLOCKTIME=200"); // 200ms
|
||||
#ifdef _WIN32
|
||||
_putenv_s("KMP_BLOCKTIME", "200"); // 200ms
|
||||
#else
|
||||
setenv("KMP_BLOCKTIME", "200", 0); // 200ms
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
|
|
|||
|
|
@ -485,8 +485,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_
|
|||
int32_t start = ith * task_per_thread;
|
||||
int32_t end = std::min((ith + 1) * task_per_thread, task_count);
|
||||
for (int32_t compute_idx = start; compute_idx < end; compute_idx++) {
|
||||
int32_t gemm_idx = compute_idx / block_size_m;
|
||||
int32_t m_idx = compute_idx % block_size_m * block_size_m;
|
||||
int32_t gemm_idx = compute_idx / per_gemm_block_count_m;
|
||||
int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m;
|
||||
int32_t m_idx = block_idx_in_gemm * block_size_m;
|
||||
const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx];
|
||||
int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx);
|
||||
|
||||
|
|
|
|||
|
|
@ -895,6 +895,7 @@ void launch_fattn(
|
|||
const dim3 block_dim(warp_size, nwarps, 1);
|
||||
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
|
||||
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
|
||||
GGML_ASSERT(max_blocks_per_sm > 0);
|
||||
int parallel_blocks = max_blocks_per_sm;
|
||||
|
||||
dim3 blocks_num;
|
||||
|
|
|
|||
|
|
@ -2822,18 +2822,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
#endif
|
||||
|
||||
//TODO: remove special case once ggml_can_fuse can handle empty nodes
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
|
||||
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
|
||||
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
|
||||
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
|
||||
|
||||
if (node_idx + topk_moe_ops_with_norm.size() > (size_t)cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
|
||||
}
|
||||
if (ops.size() == topk_moe_ops_with_norm.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+8];
|
||||
|
||||
|
|
@ -2842,16 +2839,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
|
||||
|
||||
if (node_idx + topk_moe_ops.size() > (size_t)cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < topk_moe_ops.size(); i++) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx+4];
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
|
|
@ -2859,6 +2848,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||
}
|
||||
}
|
||||
|
||||
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
|
||||
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
|
||||
|
||||
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -2946,7 +2945,8 @@ static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_
|
|||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+8];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
|
||||
/*delayed softmax*/ false);
|
||||
i += 8;
|
||||
continue;
|
||||
}
|
||||
|
|
@ -2954,11 +2954,23 @@ static void evaluate_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cuda_
|
|||
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i+4];
|
||||
ggml_tensor * selected_experts = cgraph->nodes[i+3];
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
|
||||
/*delayed softmax*/ false);
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i,
|
||||
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
|
||||
ggml_tensor * weights = cgraph->nodes[i + 5];
|
||||
ggml_tensor * ids = cgraph->nodes[i + 1];
|
||||
|
||||
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
|
||||
/*delayed_softmax*/ true);
|
||||
i += 5;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_ADD) {
|
||||
int n_fuse = 0;
|
||||
ggml_op ops[8];
|
||||
|
|
|
|||
|
|
@ -4,16 +4,61 @@
|
|||
|
||||
#include <initializer_list>
|
||||
|
||||
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
|
||||
template <int experts_per_thread, bool use_limit>
|
||||
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
|
||||
float max_val = -INFINITY;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = lane + i * WARP_SIZE;
|
||||
const bool active = !use_limit || (idx < limit);
|
||||
if (active) {
|
||||
max_val = max(max_val, vals[i]);
|
||||
}
|
||||
}
|
||||
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
float sum = 0.f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = lane + i * WARP_SIZE;
|
||||
const bool active = !use_limit || (idx < limit);
|
||||
if (active) {
|
||||
const float val = expf(vals[i] - max_val);
|
||||
vals[i] = val;
|
||||
sum += val;
|
||||
} else {
|
||||
vals[i] = 0.f;
|
||||
}
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
const float inv_sum = 1.0f / sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = lane + i * WARP_SIZE;
|
||||
const bool active = !use_limit || (idx < limit);
|
||||
if (active) {
|
||||
vals[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
This kernel does the following:
|
||||
1. softmax over the logits per token [n_experts, n_tokens]
|
||||
1. optionally softmax over the logits per token [n_experts, n_tokens]
|
||||
2. argmax reduce over the top-k (n_experts_used) logits
|
||||
3. write weights + ids to global memory
|
||||
4. optionally normalize the weights
|
||||
4. optionally normalize the weights or apply softmax over the selected logits
|
||||
|
||||
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
|
||||
*/
|
||||
template <int n_experts, bool with_norm>
|
||||
template <int n_experts, bool with_norm, bool delayed_softmax = false>
|
||||
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
|
||||
float * weights,
|
||||
int32_t * ids,
|
||||
|
|
@ -30,51 +75,30 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
|
||||
constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
|
||||
float logits_r[experts_per_thread];
|
||||
float wt[experts_per_thread];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const int expert = i + threadIdx.x;
|
||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
|
||||
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
|
||||
}
|
||||
|
||||
float max_val = logits_r[0];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
max_val = max(val, max_val);
|
||||
if constexpr (!delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
|
||||
}
|
||||
|
||||
max_val = warp_reduce_max(max_val);
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float tmp = 0.f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
wt[i] = expf(val - max_val);
|
||||
tmp += wt[i];
|
||||
}
|
||||
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = wt[i] * inv_sum;
|
||||
}
|
||||
|
||||
//at this point, each thread holds a portion of softmax,
|
||||
//we do the argmax reduce over n_expert_used, each time marking
|
||||
//at this point, each thread holds either a portion of the softmax distribution
|
||||
//or the raw logits. We do the argmax reduce over n_expert_used, each time marking
|
||||
//the expert weight as -inf to exclude from the next iteration
|
||||
|
||||
float wt_sum = 0.f;
|
||||
|
||||
extern __shared__ float data_topk_shared[];
|
||||
float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
|
||||
float output_weights[experts_per_thread];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
|
|
@ -99,10 +123,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
}
|
||||
}
|
||||
|
||||
if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
output_weights[k / WARP_SIZE] = max_val;
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
wt_shared_ptr[k] = max_val;
|
||||
ids[k] = max_expert;
|
||||
if constexpr (with_norm) {
|
||||
wt_sum += max_val;
|
||||
|
|
@ -114,17 +141,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
|
|||
wt_sum = warp_reduce_sum(wt_sum);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
||||
wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
output_weights[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
|
||||
weights[i] = wt_shared_ptr[i];
|
||||
if constexpr (delayed_softmax) {
|
||||
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const int idx = i * WARP_SIZE + threadIdx.x;
|
||||
if (idx < n_expert_used) {
|
||||
weights[idx] = output_weights[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool with_norm>
|
||||
template <bool with_norm, bool delayed_softmax = false>
|
||||
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
const float * logits,
|
||||
float * weights,
|
||||
|
|
@ -132,53 +167,53 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
|||
const int n_rows,
|
||||
const int n_expert,
|
||||
const int n_expert_used) {
|
||||
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
|
||||
|
||||
const int rows_per_block = 4;
|
||||
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
|
||||
dim3 block_dims(WARP_SIZE, rows_per_block, 1);
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
|
||||
|
||||
switch (n_expert) {
|
||||
case 1:
|
||||
topk_moe_cuda<1, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<1, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 2:
|
||||
topk_moe_cuda<2, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<2, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 4:
|
||||
topk_moe_cuda<4, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<4, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 8:
|
||||
topk_moe_cuda<8, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<8, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 16:
|
||||
topk_moe_cuda<16, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<16, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 32:
|
||||
topk_moe_cuda<32, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<32, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 64:
|
||||
topk_moe_cuda<64, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<64, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 128:
|
||||
topk_moe_cuda<128, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<128, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 256:
|
||||
topk_moe_cuda<256, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<256, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
case 512:
|
||||
topk_moe_cuda<512, with_norm>
|
||||
<<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
topk_moe_cuda<512, with_norm, delayed_softmax>
|
||||
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "fatal error");
|
||||
|
|
@ -190,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm) {
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax) {
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
|
@ -198,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
const int n_experts = logits->ne[0];
|
||||
const int n_rows = logits->ne[1];
|
||||
|
||||
const float * logits_d = (const float *) logits->src[0]->data;
|
||||
const float * logits_d = (const float *) logits->data;
|
||||
float * weights_d = (float *) weights->data;
|
||||
int32_t * ids_d = (int32_t *) ids->data;
|
||||
|
||||
|
|
@ -209,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
|||
if (with_norm) {
|
||||
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
if (delayed_softmax) {
|
||||
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
} else {
|
||||
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -242,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
|
|||
return true;
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
|
||||
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
|
|
@ -250,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
|
|||
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
|
||||
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
|
||||
|
||||
GGML_ASSERT(!norm || !delayed_softmax);
|
||||
|
||||
if (delayed_softmax) {
|
||||
return delayed_softmax_ops;
|
||||
}
|
||||
|
||||
if (norm) {
|
||||
return norm_ops;
|
||||
}
|
||||
|
||||
return no_norm_ops;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,9 +6,10 @@
|
|||
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * logits,
|
||||
ggml_tensor * weights,
|
||||
ggml_tensor * top_k,
|
||||
const bool with_norm);
|
||||
ggml_tensor * ids,
|
||||
const bool with_norm,
|
||||
const bool delayed_softmax = false);
|
||||
|
||||
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
|
||||
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
|
||||
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||
|
||||
set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_include_directories(htp_iface PUBLIC
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${HEXAGON_SDK_ROOT}/utils/examples
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/htp
|
||||
${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
build_idl(htp/htp_iface.idl htp_iface)
|
||||
|
||||
if (CMAKE_SYSTEM_NAME MATCHES Android)
|
||||
target_link_options(htp_iface PUBLIC -llog -ldl)
|
||||
elseif (CMAKE_SYSTEM_NAME MATCHES Windows)
|
||||
target_precompile_headers(htp_iface PUBLIC <sal.h>)
|
||||
else()
|
||||
target_link_options(htp_iface PUBLIC -ldl)
|
||||
endif()
|
||||
|
||||
link_custom_library(htp_iface cdsprpc)
|
||||
link_custom_library(htp_iface rpcmem)
|
||||
|
||||
set(TARGET_NAME ggml-hexagon)
|
||||
ggml_add_backend_library(${TARGET_NAME}
|
||||
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
# Build HTP bits
|
||||
set(HTP_CMAKE_ARGS
|
||||
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
|
||||
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG})
|
||||
|
||||
ExternalProject_Add(htp-v73
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
|
||||
|
||||
ExternalProject_Add(htp-v75
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
|
||||
|
||||
ExternalProject_Add(htp-v79
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
|
||||
|
||||
ExternalProject_Add(htp-v81
|
||||
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
|
||||
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
|
||||
|
||||
# Install Hexagon skels required at runtime
|
||||
install(FILES
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so
|
||||
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so
|
||||
TYPE LIB)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,448 @@
|
|||
|
||||
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
|
||||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wsign-compare"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
#include "ggml-hexagon.h"
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#include "htp-utils.h"
|
||||
|
||||
#include <domain.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
domain * get_domain(int domain_id) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return &supported_domains[i];
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only) {
|
||||
int i = 0;
|
||||
int size = sizeof(supported_domains) / sizeof(domain);
|
||||
|
||||
if (compute_only) {
|
||||
return is_CDSP(domain_id);
|
||||
}
|
||||
|
||||
for (i = 0; i < size; i++) {
|
||||
if (supported_domains[i].id == domain_id) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
int ss_info = 0;
|
||||
if (domain_type != NULL) {
|
||||
if (strcmp(domain_type, "LPASS") == 0) {
|
||||
ss_info = FASTRPC_LPASS;
|
||||
} else if (strcmp(domain_type, "HPASS") == 0) {
|
||||
ss_info = FASTRPC_HPASS;
|
||||
} else {
|
||||
ss_info = FASTRPC_NSP;
|
||||
}
|
||||
}
|
||||
system_req_payload req = { 0 };
|
||||
req.id = FASTRPC_GET_DOMAINS;
|
||||
req.sys.domains = NULL;
|
||||
fastrpc_domain * domain = NULL;
|
||||
if (ss_info != 0) {
|
||||
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
|
||||
} else {
|
||||
req.sys.flags = 0;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
#endif
|
||||
if (remote_system_request) {
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
// Allocate memory for domain-info array
|
||||
req.sys.max_domains = req.sys.num_domains;
|
||||
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
|
||||
nErr = AEE_ENOMEMORY;
|
||||
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
nErr = remote_system_request(&req);
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
|
||||
for (int i = 0; i < req.sys.num_domains; i++) {
|
||||
// Verify that only requested type domains were returned
|
||||
domain = &req.sys.domains[i];
|
||||
if (domain->type != ss_info && domain_type != NULL) {
|
||||
nErr = -1;
|
||||
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
|
||||
goto bail;
|
||||
}
|
||||
}
|
||||
*domains_info = req.sys.domains;
|
||||
*num_domains = req.sys.num_domains;
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
goto bail;
|
||||
}
|
||||
bail:
|
||||
if (nErr && !req.sys.domains) {
|
||||
free(req.sys.domains);
|
||||
}
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
|
||||
int err = 0;
|
||||
remote_rpc_effective_domain_id_t sess = { 0 };
|
||||
|
||||
sess.domain_name = domain_name;
|
||||
sess.domain_name_len = strlen(domain_name);
|
||||
sess.session_id = session_id;
|
||||
|
||||
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
|
||||
if (err) {
|
||||
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
|
||||
session_id);
|
||||
return err;
|
||||
}
|
||||
|
||||
*effec_domain_id = sess.effective_domain_id;
|
||||
return err;
|
||||
}
|
||||
|
||||
int get_dsp_support(int * domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
|
||||
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
goto bail;
|
||||
}
|
||||
|
||||
if (dsp_capability_domain.capability == 0) {
|
||||
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
|
||||
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
|
||||
dsp_capability_domain.capability = 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if (dsp_capability_domain.capability) {
|
||||
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
|
||||
}
|
||||
}
|
||||
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
|
||||
} else {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for VTCM information
|
||||
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_vtcm_dsp;
|
||||
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_vtcm_dsp.attribute_ID = attr;
|
||||
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_vtcm_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
if (nErr) {
|
||||
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
|
||||
return false;
|
||||
}
|
||||
if (dsp_capability_domain.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool get_unsignedpd_support(void) {
|
||||
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
|
||||
}
|
||||
|
||||
bool is_async_fastrpc_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
|
||||
* Async fastrpc is supported only on CDSP
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_async_support;
|
||||
dsp_capability_async_support.domain = (uint32_t) domain;
|
||||
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
|
||||
dsp_capability_async_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_async_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_status_notification_supported(int domain) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
|
||||
if (remote_handle_control) {
|
||||
/*
|
||||
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
* DSP User PD status notification Support
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_status_notification_support;
|
||||
dsp_capability_status_notification_support.domain = (uint32_t) domain;
|
||||
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
|
||||
dsp_capability_status_notification_support.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (dsp_capability_status_notification_support.capability == 1) {
|
||||
return true;
|
||||
}
|
||||
if (nErr != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return false;
|
||||
}
|
||||
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
|
||||
nErr = AEE_EBADPARM;
|
||||
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
|
||||
goto bail;
|
||||
}
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HMX SUPPORT information
|
||||
* HMX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hmx_dsp;
|
||||
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hmx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hmx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
||||
int get_hex_arch_ver(int domain, int * arch) {
|
||||
if (!remote_handle_control) {
|
||||
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
struct remote_dsp_capability arch_ver;
|
||||
arch_ver.domain = (uint32_t) domain;
|
||||
arch_ver.attribute_ID = ARCH_VER;
|
||||
arch_ver.capability = (uint32_t) 0;
|
||||
|
||||
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
|
||||
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
|
||||
return AEE_EUNSUPPORTEDAPI;
|
||||
}
|
||||
|
||||
if (err != AEE_SUCCESS) {
|
||||
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
|
||||
return err;
|
||||
}
|
||||
|
||||
switch (arch_ver.capability & 0xff) {
|
||||
case 0x73:
|
||||
*arch = 73;
|
||||
return 0;
|
||||
case 0x75:
|
||||
*arch = 75;
|
||||
return 0;
|
||||
case 0x79:
|
||||
*arch = 79;
|
||||
return 0;
|
||||
case 0x81:
|
||||
*arch = 81;
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
|
||||
int nErr = AEE_SUCCESS;
|
||||
*capability = 0;
|
||||
|
||||
if (remote_handle_control) {
|
||||
if (domain == CDSP_DOMAIN_ID) {
|
||||
/*
|
||||
* Query the DSP for HVX SUPPORT information
|
||||
* HVX is supported on CDSP only
|
||||
*/
|
||||
struct remote_dsp_capability dsp_capability_hvx_dsp;
|
||||
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
|
||||
dsp_capability_hvx_dsp.attribute_ID = attr;
|
||||
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
|
||||
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
|
||||
sizeof(struct remote_dsp_capability));
|
||||
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
|
||||
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
|
||||
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
|
||||
nErr = AEE_SUCCESS;
|
||||
goto bail;
|
||||
} else if (nErr == AEE_SUCCESS) {
|
||||
*capability = dsp_capability_hvx_dsp.capability;
|
||||
} else {
|
||||
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTED;
|
||||
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
|
||||
goto bail;
|
||||
}
|
||||
} else {
|
||||
nErr = AEE_EUNSUPPORTEDAPI;
|
||||
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
|
||||
}
|
||||
|
||||
bail:
|
||||
return nErr;
|
||||
}
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
#ifndef HTP_UTILS_H
|
||||
#define HTP_UTILS_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include <AEEStdErr.h>
|
||||
#include <inttypes.h>
|
||||
#include <remote.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
/* Offset to differentiate HLOS and Hexagon error codes.
|
||||
Stores the value of AEE_EOFFSET for Hexagon. */
|
||||
#ifndef DSP_OFFSET
|
||||
# define DSP_OFFSET 0x80000400
|
||||
#endif
|
||||
|
||||
/* Errno for connection reset by peer. */
|
||||
#ifndef ECONNRESET
|
||||
# ifdef __hexagon__
|
||||
# define ECONNRESET 104
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Abstraction of different OS specific sleep APIs.
|
||||
SLEEP accepts input in seconds. */
|
||||
#ifndef SLEEP
|
||||
# ifdef __hexagon__
|
||||
# define SLEEP(x) \
|
||||
{ /* Do nothing for simulator. */ \
|
||||
}
|
||||
# else
|
||||
# ifdef _WINDOWS
|
||||
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
|
||||
# else
|
||||
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
|
||||
# endif
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* Include windows specific header files. */
|
||||
#ifdef _WINDOWS
|
||||
# include <sysinfoapi.h>
|
||||
# include <windows.h>
|
||||
# define _CRT_SECURE_NO_WARNINGS 1
|
||||
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
|
||||
/* Including this file for custom implementation of getopt function. */
|
||||
# include "getopt_custom.h"
|
||||
#endif
|
||||
|
||||
/* Includes and defines for all HLOS except windows */
|
||||
#if !defined(__hexagon__) && !defined(_WINDOWS)
|
||||
# include "unistd.h"
|
||||
|
||||
# include <sys/time.h>
|
||||
#endif
|
||||
|
||||
/* Includes and defines for Hexagon and all HLOS except Windows. */
|
||||
#if !defined(_WINDOWS)
|
||||
/* Weak reference to remote symbol for compilation. */
|
||||
# pragma weak remote_session_control
|
||||
# pragma weak remote_handle_control
|
||||
# pragma weak remote_handle64_control
|
||||
# pragma weak fastrpc_mmap
|
||||
# pragma weak fastrpc_munmap
|
||||
#endif
|
||||
|
||||
#if !defined(_WINDOWS)
|
||||
# pragma weak remote_system_request
|
||||
#endif
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query DSP support.
|
||||
*
|
||||
* @param[out] domain pointer to supported domain.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_dsp_support(int * domain);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query VTCM information.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*/
|
||||
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
|
||||
*
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool get_unsignedpd_support(void);
|
||||
|
||||
/**
|
||||
* Wrapper for FastRPC Capability API: query unsigned pd support.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @return true if unsigned pd is supported.
|
||||
* false if unsigned pd is not supported, capability query failed.
|
||||
*/
|
||||
|
||||
bool is_unsignedpd_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_valid_domain_id API: query a domain id is valid.
|
||||
*
|
||||
* @param[in] domain value of domain in the queried.
|
||||
* @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
|
||||
* @return true if value of domain is valid.
|
||||
* false if value of domain is not valid.
|
||||
*/
|
||||
|
||||
bool is_valid_domain_id(int domain_id, int compute_only);
|
||||
|
||||
/**
|
||||
* get_domain API: get domain struct from domain value.
|
||||
*
|
||||
* @param[in] domain value of a domain
|
||||
* @return Returns domain struct of the domain if it is supported or else
|
||||
* returns NULL.
|
||||
*
|
||||
*/
|
||||
|
||||
domain * get_domain(int domain_id);
|
||||
|
||||
/**
|
||||
* get_domains_info API: get information for all the domains available on the device
|
||||
*
|
||||
* @param[in] domain_type pointer to domain type
|
||||
* @param[in] num_domains pointer to number of domains
|
||||
* @param[in] domains_info pointer to save discovered domains information.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
* It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
|
||||
|
||||
/**
|
||||
* get_effective_domain_id API: get effective domain id for given session id
|
||||
*
|
||||
* @param[in] domain_name pointer to domain name
|
||||
* @param[in] session_id
|
||||
* @param[in] effec_domain_id pointer to save obtained effective domain id.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
|
||||
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
|
||||
|
||||
/**
|
||||
* is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating support of Async FastRPC
|
||||
*
|
||||
*/
|
||||
|
||||
bool is_async_fastrpc_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @return Returns true or false stating status notification support information
|
||||
*
|
||||
*/
|
||||
bool is_status_notification_supported(int domain_id);
|
||||
|
||||
/**
|
||||
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
/**
|
||||
* get_hex_arch_ver API: query the Hexagon processor architecture version information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] Arch version (73, 75, ...)
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hex_arch_ver(int domain, int * arch);
|
||||
|
||||
/**
|
||||
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
|
||||
*
|
||||
* @param[in] domain_id value of a domain
|
||||
* @param[out] capability capability value of the attribute queried.
|
||||
* @param[in] attr value of the attribute to the queried.
|
||||
* @return 0 if query is successful.
|
||||
* non-zero if error, return value points to the error.
|
||||
*
|
||||
*/
|
||||
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif //DSP_CAPABILITIES_UTILS_H
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
cmake_minimum_required(VERSION 3.22.2)
|
||||
project(ggml-htp C CXX ASM)
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
|
||||
include_directories(
|
||||
${HEXAGON_SDK_ROOT}/incs
|
||||
${HEXAGON_SDK_ROOT}/incs/stddef
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/..
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
set(HTP_LIB ggml-htp-${DSP_VERSION})
|
||||
|
||||
add_library(${HTP_LIB} SHARED
|
||||
main.c
|
||||
htp_iface_skel.c
|
||||
worker-pool.c
|
||||
htp-dma.c
|
||||
hvx-sigmoid.c
|
||||
hvx-inverse.c
|
||||
hvx-exp.c
|
||||
hvx-utils.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>)
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
install(TARGETS ${HTP_LIB})
|
||||
|
|
@ -0,0 +1,448 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_act_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
#define htp_act_preamble2 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
}
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
const int32_t swapped = op_params[1];
|
||||
|
||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (!src1_valid) {
|
||||
src0 += swapped ? nc : 0;
|
||||
src1 += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
|
||||
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
|
||||
(uint8_t *) dst, nc);
|
||||
} else {
|
||||
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
|
||||
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
|
||||
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
|
||||
|
||||
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
|
||||
hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble3;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
bool src1_valid = src1->ne[0];
|
||||
if (!src1_valid) {
|
||||
data_src1 = data_src0;
|
||||
}
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
const int32_t swapped = op_params[1];
|
||||
const float alpha = ((const float *) (op_params))[2];
|
||||
const float limit = ((const float *) (op_params))[3];
|
||||
|
||||
const int nc = (src1_valid) ? ne0 : ne0 / 2;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (!src1) {
|
||||
src0 += swapped ? nc : 0;
|
||||
src1 += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
// x (src0_spad_data) = std::min(src0_p[k], limit);
|
||||
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
|
||||
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
|
||||
hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
|
||||
// y (src1_spad_data) = y1 + 1.f
|
||||
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
|
||||
// x1 (dst_spad_data) = alpha * (x)
|
||||
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
|
||||
// x2 (dst_spad_data) = expf(-x1)
|
||||
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
|
||||
// x3 (dst_spad_data) = x2 + 1.f
|
||||
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
|
||||
// x4 (dst_spad_data) = 1 / x3
|
||||
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
|
||||
// out_glu(dst_spad_data) = x * x4
|
||||
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
|
||||
// out = out_glu * (y + 1.f);
|
||||
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
|
||||
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
|
||||
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
|
||||
hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||
} else {
|
||||
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
|
||||
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
|
||||
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
|
||||
|
||||
hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_activations_fp32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) {
|
||||
FARF(ERROR, "Non-contiguous tensors are not supported at this time \n");
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
worker_callback_t act_op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_fp32;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_fp32;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_fp32;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * octx->n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_activations_fp32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,344 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0,
|
||||
const uint8_t * src1,
|
||||
uint8_t * data_dst,
|
||||
const int num_elems);
|
||||
|
||||
static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 };
|
||||
static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt };
|
||||
|
||||
#define htp_binary_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne10 = src1->ne[0]; \
|
||||
const uint32_t ne11 = src1->ne[1]; \
|
||||
const uint32_t ne12 = src1->ne[2]; \
|
||||
const uint32_t ne13 = src1->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb10 = src1->nb[0]; \
|
||||
const uint32_t nb11 = src1->nb[1]; \
|
||||
const uint32_t nb12 = src1->nb[2]; \
|
||||
const uint32_t nb13 = src1->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void binary_job_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad_data,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
enum htp_op op) {
|
||||
htp_binary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op];
|
||||
|
||||
uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size);
|
||||
|
||||
const uint32_t nr0 = ne00 / ne10;
|
||||
|
||||
const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size);
|
||||
uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size);
|
||||
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
const uint8_t * restrict src1_ptr = NULL;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size;
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (nr0 > 1) {
|
||||
if ((1 == is_aligned) && (nr0 == ne00)) {
|
||||
hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0);
|
||||
} else {
|
||||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11);
|
||||
}
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
}
|
||||
|
||||
src0_ptr += src0_row_size;
|
||||
dst_ptr += dst_row_size;
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
|
||||
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
const struct htp_tensor * src2,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad_data,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
hvx_elemwise_f32_func func_HVX) {
|
||||
htp_binary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t src1_row_size = nb11;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t ne02_ne01 = ne02 * ne01;
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n");
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
|
||||
// src0 indices
|
||||
const uint32_t i03 = ir / ne02_ne01;
|
||||
const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01;
|
||||
const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01);
|
||||
|
||||
// src1 indices
|
||||
const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]);
|
||||
assert(i11 >= 0 && i11 < ne11);
|
||||
|
||||
float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1);
|
||||
const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01);
|
||||
const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11);
|
||||
|
||||
if (ir + 1 < src0_end_row) {
|
||||
htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size);
|
||||
if (src1_row_size == src0_row_size) {
|
||||
htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size);
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t nr0 = ne00 / ne10;
|
||||
if (nr0 > 1) {
|
||||
for (uint32_t r = 0; r < nr0; r++) {
|
||||
memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10);
|
||||
}
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00);
|
||||
} else {
|
||||
func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1],
|
||||
src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1],
|
||||
dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i,
|
||||
octx->src0_nrows_per_thread, octx->op);
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD_ID:
|
||||
binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n,
|
||||
i, octx->src0_nrows_per_thread, hvx_add_f32);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Binary Op %u", octx->op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
static int execute_op_binary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t binary_op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "mul-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "add-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_SUB:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "sub-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD_ID:
|
||||
binary_op_func = binary_job_dispatcher_f32;
|
||||
op_type = "add-id-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported binary-Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src1->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
FARF(HIGH,
|
||||
"%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type,
|
||||
octx->ctx->vtcm_size, spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_binary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_binary_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,157 @@
|
|||
if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||
return()
|
||||
endif()
|
||||
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||
|
||||
#Cross Compiling for Hexagon
|
||||
set(HEXAGON TRUE)
|
||||
set(CMAKE_SYSTEM_NAME QURT)
|
||||
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||
set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL})
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CUSTOM_RUNELF_PATH "")
|
||||
|
||||
#To fix backward compatibility with EAI addon.
|
||||
if (NOT HEXAGON_SDK_ROOT)
|
||||
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||
endif()
|
||||
|
||||
if (NOT HEXAGON_TOOLS_ROOT)
|
||||
if (DEFINED ENV{HEXAGON_TOOLS_ROOT})
|
||||
set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT})
|
||||
endif()
|
||||
if(NOT HEXAGON_TOOLS_ROOT)
|
||||
set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
|
||||
#Get the Binary extension of the Hexagon Toolchain
|
||||
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||
endif()
|
||||
message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}")
|
||||
|
||||
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake)
|
||||
|
||||
set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT})
|
||||
set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib")
|
||||
set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss)
|
||||
|
||||
set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||
HEXAGON_SDK_ROOT
|
||||
HEXAGON_TOOLS_ROOT
|
||||
)
|
||||
|
||||
#QURT Related includes and linker flags
|
||||
set(V_ARCH ${HEXAGON_ARCH})
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
|
||||
if( ${TREE} MATCHES PAKMAN )
|
||||
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
endif()
|
||||
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||
set(RTOS_DIR ${_QURT_INSTALL_DIR})
|
||||
set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0")
|
||||
set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0")
|
||||
|
||||
include_directories(
|
||||
${_QURT_INSTALL_DIR}/include
|
||||
${_QURT_INSTALL_DIR}/include/qurt
|
||||
${_QURT_INSTALL_DIR}/include/posix
|
||||
)
|
||||
|
||||
set(QURT_START_LINK_LIBS)
|
||||
set(QURT_START_LINK_LIBS
|
||||
"${TARGET_DIR}/init.o"
|
||||
"${RTOS_DIR}/lib/crt1.o"
|
||||
"${RTOS_DIR}/lib/debugmon.o"
|
||||
"${RTOS_DIR}/lib/libqurt.a"
|
||||
"${TARGET_DIR}/libc.a"
|
||||
"${TARGET_DIR}/libqcc.a"
|
||||
"${TARGET_DIR}/libhexagon.a"
|
||||
"${RTOS_DIR}/lib/libqurtcfs.a"
|
||||
"${RTOS_DIR}/lib/libtimer_island.a"
|
||||
"${RTOS_DIR}/lib/libtimer_main.a"
|
||||
"${RTOS_DIR}/lib/libposix.a"
|
||||
)
|
||||
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||
|
||||
set(QURT_END_LINK_LIBS
|
||||
${TARGET_DIR}/fini.o
|
||||
)
|
||||
|
||||
#Non QURT related includes and linker flags
|
||||
|
||||
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||
|
||||
if (NOT NO_WRAP_MEM_API)
|
||||
set(WRAP_MALLOC -Wl,--wrap=malloc)
|
||||
set(WRAP_CALLOC -Wl,--wrap=calloc)
|
||||
set(WRAP_FREE -Wl,--wrap=free)
|
||||
set(WRAP_REALLOC -Wl,--wrap=realloc)
|
||||
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||
endif()
|
||||
|
||||
set(PIC_SHARED_LD_FLAGS
|
||||
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||
-G0
|
||||
-fpic
|
||||
-Wl,-Bsymbolic
|
||||
-Wl,-L${TARGET_DIR_NOOS}/G0/pic
|
||||
-Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/
|
||||
-Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN}
|
||||
-shared
|
||||
"-o <TARGET> <SONAME_FLAG><TARGET_SONAME>"
|
||||
"<LINK_FLAGS>"
|
||||
-Wl,--start-group
|
||||
"<OBJECTS>"
|
||||
"<LINK_LIBRARIES>"
|
||||
-Wl,--end-group
|
||||
-lc
|
||||
)
|
||||
STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
#System include paths
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||
|
||||
#LLVM toolchain setup
|
||||
#Compiler paths, options and architecture
|
||||
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(HEXAGON_LINKER ${CMAKE_C_COMPILER})
|
||||
set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g")
|
||||
set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3")
|
||||
|
||||
set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}")
|
||||
set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}")
|
||||
set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" )
|
||||
|
||||
#Linker Options
|
||||
set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||
set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}")
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
#ifndef HTP_CTX_H
|
||||
#define HTP_CTX_H
|
||||
|
||||
#include "htp-dma.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <dspqueue.h>
|
||||
#include <stdatomic.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
|
||||
// FIXME: move these into matmul-ops
|
||||
#define HTP_SPAD_SRC0_NROWS 16
|
||||
#define HTP_SPAD_SRC1_NROWS 16
|
||||
#define HTP_SPAD_DST_NROWS 2
|
||||
|
||||
// Main context for htp DSP backend
|
||||
struct htp_context {
|
||||
dspqueue_t queue;
|
||||
dma_queue * dma[HTP_MAX_NTHREADS];
|
||||
worker_pool_context_t worker_pool;
|
||||
uint32_t n_threads;
|
||||
|
||||
int thread_id;
|
||||
int thread_prio;
|
||||
|
||||
uint8_t * vtcm_base;
|
||||
size_t vtcm_size;
|
||||
uint32_t vtcm_rctx;
|
||||
|
||||
atomic_bool vtcm_valid;
|
||||
atomic_bool vtcm_inuse;
|
||||
atomic_bool vtcm_needs_release;
|
||||
|
||||
uint32_t opmask;
|
||||
};
|
||||
|
||||
#endif /* HTP_CTX_H */
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
#include "htp-dma.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
static inline uint32_t pow2_ceil(uint32_t x) {
|
||||
if (x <= 1) {
|
||||
return 1;
|
||||
}
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) {
|
||||
p <<= 1;
|
||||
}
|
||||
return p;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue_create(size_t capacity) {
|
||||
dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue));
|
||||
if (q == NULL) {
|
||||
FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
capacity = pow2_ceil(capacity);
|
||||
|
||||
memset(q, 0, sizeof(dma_queue));
|
||||
q->capacity = capacity;
|
||||
q->idx_mask = capacity - 1;
|
||||
|
||||
q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t));
|
||||
memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t));
|
||||
|
||||
q->dst = (void **) memalign(4, capacity * sizeof(void *));
|
||||
memset(q->dst, 0, capacity * sizeof(void *));
|
||||
|
||||
q->tail = &q->desc[capacity - 1];
|
||||
|
||||
if (!q->desc && !q->dst) {
|
||||
FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
FARF(HIGH, "dma-queue: capacity %u\n", capacity);
|
||||
|
||||
return q;
|
||||
}
|
||||
|
||||
void dma_queue_delete(dma_queue * q) {
|
||||
if (!q) {
|
||||
return;
|
||||
}
|
||||
free(q->desc);
|
||||
free(q->dst);
|
||||
free(q);
|
||||
}
|
||||
|
||||
void dma_queue_flush(dma_queue * q) {
|
||||
while (1) {
|
||||
uint32_t s = dmwait() & 0x3;
|
||||
if (s == HEXAGON_UDMA_DM0_STATUS_IDLE) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
q->tail = NULL;
|
||||
}
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
#ifndef HTP_DMA_H
|
||||
#define HTP_DMA_H
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
hexagon_udma_descriptor_type1_t * desc; // descriptor pointers
|
||||
hexagon_udma_descriptor_type1_t * tail; // tail pointer
|
||||
void ** dst; // dst pointers
|
||||
uint32_t push_idx;
|
||||
uint32_t pop_idx;
|
||||
uint32_t capacity;
|
||||
uint32_t idx_mask;
|
||||
} dma_queue;
|
||||
|
||||
dma_queue * dma_queue_create(size_t capacity);
|
||||
void dma_queue_delete(dma_queue * q);
|
||||
void dma_queue_flush(dma_queue * q);
|
||||
|
||||
// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead
|
||||
// but those do not seem to always compiler properly.
|
||||
static inline void dmstart(void * next) {
|
||||
asm volatile(" release(%0):at" : : "r"(next));
|
||||
asm volatile(" dmstart(%0)" : : "r"(next));
|
||||
}
|
||||
|
||||
static inline void dmlink(void * cur, void * next) {
|
||||
asm volatile(" release(%0):at" : : "r"(next));
|
||||
asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next));
|
||||
}
|
||||
|
||||
static inline unsigned int dmpoll(void) {
|
||||
unsigned int ret = 0;
|
||||
asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory");
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline unsigned int dmwait(void) {
|
||||
unsigned int ret = 0;
|
||||
asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory");
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline bool dma_queue_push(dma_queue * q,
|
||||
void * dst,
|
||||
const void * src,
|
||||
size_t dst_row_size,
|
||||
size_t src_row_size,
|
||||
size_t nrows) {
|
||||
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
|
||||
return false;
|
||||
}
|
||||
|
||||
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx];
|
||||
|
||||
desc->next = NULL;
|
||||
desc->length = 0;
|
||||
desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1;
|
||||
desc->dstbypass = 1;
|
||||
desc->srcbypass = 1;
|
||||
desc->order = 0;
|
||||
desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE;
|
||||
desc->src = (void *) src;
|
||||
desc->dst = (void *) dst;
|
||||
desc->allocation = 0;
|
||||
desc->padding = 0;
|
||||
desc->roiwidth = src_row_size;
|
||||
desc->roiheight = nrows;
|
||||
desc->srcstride = src_row_size;
|
||||
desc->dststride = dst_row_size;
|
||||
desc->srcwidthoffset = 0;
|
||||
desc->dstwidthoffset = 0;
|
||||
|
||||
q->dst[q->push_idx] = dst;
|
||||
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline uint8_t * dma_queue_pop(dma_queue * q) {
|
||||
if (q->push_idx == q->pop_idx) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx];
|
||||
|
||||
// Wait for desc to complete
|
||||
while (1) {
|
||||
dmpoll();
|
||||
if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) {
|
||||
break;
|
||||
}
|
||||
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
|
||||
}
|
||||
|
||||
uint8_t * dst = (uint8_t *) q->dst[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dst;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
#endif /* HTP_DMA_H */
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
#ifndef HTP_MSG_H
|
||||
#define HTP_MSG_H
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
// ggml-common.h must be included prio to this header
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
// Used for debugging and profiling.
|
||||
enum {
|
||||
HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP)
|
||||
HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize
|
||||
HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute
|
||||
};
|
||||
|
||||
// Op flags
|
||||
enum {
|
||||
HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors)
|
||||
HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling)
|
||||
HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification
|
||||
};
|
||||
|
||||
enum htp_status {
|
||||
HTP_STATUS_OK = 1,
|
||||
HTP_STATUS_INTERNAL_ERR = 2,
|
||||
HTP_STATUS_NO_SUPPORT = 3,
|
||||
HTP_STATUS_INVAL_PARAMS = 4,
|
||||
HTP_STATUS_VTCM_TOO_SMALL = 5,
|
||||
};
|
||||
|
||||
// The values must match the ggml_type.
|
||||
// Duplicated here because we can't include full ggml.h in the htp build.
|
||||
// We have some static_asserts in the cpp code to ensure things are in sync.
|
||||
enum htp_data_type {
|
||||
HTP_TYPE_F32 = 0,
|
||||
HTP_TYPE_F16 = 1,
|
||||
HTP_TYPE_Q4_0 = 2,
|
||||
HTP_TYPE_Q8_0 = 8,
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
HTP_TYPE_COUNT
|
||||
};
|
||||
|
||||
// These values are manually translated over to HTP
|
||||
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
|
||||
enum htp_op {
|
||||
HTP_OP_MUL = 0,
|
||||
HTP_OP_ADD = 1,
|
||||
HTP_OP_SUB = 2,
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT = 4,
|
||||
HTP_OP_MUL_MAT_ID = 5,
|
||||
HTP_OP_RMS_NORM = 6,
|
||||
HTP_OP_UNARY_SILU = 7,
|
||||
HTP_OP_GLU_SWIGLU = 8,
|
||||
HTP_OP_GLU_SWIGLU_OAI = 9,
|
||||
HTP_OP_SOFTMAX = 10,
|
||||
HTP_OP_ADD_ID = 11,
|
||||
HTP_OP_ROPE = 12,
|
||||
INVALID
|
||||
};
|
||||
|
||||
static inline size_t htp_type_block_size(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 1;
|
||||
case HTP_TYPE_F16:
|
||||
return 1;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return QK4_0;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return QK8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return QK_MXFP4;
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static inline size_t htp_type_nbytes(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return 4;
|
||||
case HTP_TYPE_F16:
|
||||
return 2;
|
||||
case HTP_TYPE_Q4_0:
|
||||
return sizeof(block_q4_0);
|
||||
case HTP_TYPE_Q8_0:
|
||||
return sizeof(block_q8_0);
|
||||
case HTP_TYPE_MXFP4:
|
||||
return sizeof(block_mxfp4);
|
||||
default:
|
||||
assert(0 && "unsupported HTP data type");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static const char * htp_type_name(uint32_t t) {
|
||||
switch (t) {
|
||||
case HTP_TYPE_F32:
|
||||
return "fp32";
|
||||
case HTP_TYPE_F16:
|
||||
return "fp16";
|
||||
case HTP_TYPE_Q4_0:
|
||||
return "q4_0";
|
||||
case HTP_TYPE_Q8_0:
|
||||
return "q8_0";
|
||||
case HTP_TYPE_MXFP4:
|
||||
return "mxfp4";
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
|
||||
#define HTP_MAX_DIMS 4
|
||||
|
||||
struct htp_tensor {
|
||||
uint32_t data; // Buffer offset in the messages, and data pointer on the NSP
|
||||
uint32_t type; // Data type
|
||||
uint32_t ne[HTP_MAX_DIMS]; // Number of elements
|
||||
uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor)
|
||||
};
|
||||
|
||||
#define HTP_MAX_OP_PARAMS 64
|
||||
|
||||
struct htp_general_req {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
// Params for the op, e.g. epsilon of RMS norm
|
||||
uint32_t flags; // Request flags
|
||||
|
||||
struct htp_tensor src0; // Input0 tensor
|
||||
struct htp_tensor src1; // Input1 tensor
|
||||
struct htp_tensor src2; // Input2 tensor
|
||||
struct htp_tensor dst; // Output tensor
|
||||
|
||||
// should be multiple of 64 bytes (cacheline)
|
||||
};
|
||||
|
||||
struct htp_general_rsp {
|
||||
uint32_t op; // GGML/HTP Op
|
||||
uint32_t status; // HTP_STATUS_...
|
||||
uint32_t prof_usecs; // Number of usec per request
|
||||
uint32_t prof_cycles; // Number of cycles per request
|
||||
uint32_t prof_pkts; // Number of instruction packets per request
|
||||
uint8_t unused[44]; // Pad to 64 bytes
|
||||
};
|
||||
|
||||
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
|
||||
#define HTP_MAX_PACKET_BUFFERS 4
|
||||
|
||||
#endif /* HTP_MSG_H */
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
#ifndef HTP_OPS_H
|
||||
#define HTP_OPS_H
|
||||
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-msg.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// ggml-common.h must be included prior to this header
|
||||
|
||||
struct htp_spad {
|
||||
uint8_t * data;
|
||||
size_t size;
|
||||
size_t size_per_thread;
|
||||
};
|
||||
|
||||
struct htp_ops_context {
|
||||
struct htp_context * ctx;
|
||||
|
||||
enum htp_op op;
|
||||
int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
|
||||
|
||||
struct htp_tensor src0;
|
||||
struct htp_tensor src1;
|
||||
struct htp_tensor src2;
|
||||
struct htp_tensor dst;
|
||||
|
||||
struct htp_spad src0_spad;
|
||||
struct htp_spad src1_spad;
|
||||
struct htp_spad src2_spad;
|
||||
struct htp_spad dst_spad;
|
||||
|
||||
worker_pool_context_t * wpool; // worker pool
|
||||
uint32_t n_threads; // num threads
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_activations(struct htp_ops_context * octx);
|
||||
int op_softmax(struct htp_ops_context * octx);
|
||||
int op_add_id(struct htp_ops_context * octx);
|
||||
int op_rope(struct htp_ops_context * octx);
|
||||
|
||||
#endif /* HTP_OPS_H */
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
// FastRPC IDL interface for GGML HTP
|
||||
|
||||
#ifndef HTP_IDL
|
||||
#define HTP_IDL
|
||||
|
||||
#include "AEEStdDef.idl"
|
||||
#include "remote.idl"
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx);
|
||||
AEEResult stop();
|
||||
AEEResult enable_etm();
|
||||
AEEResult disable_etm();
|
||||
};
|
||||
|
||||
#endif /* HTP_IDL */
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector vec_out = Q6_V_vzero();
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++);
|
||||
*p_vec_out++ = hvx_vec_exp_fp32(neg_vec_in);
|
||||
} else {
|
||||
*p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in);
|
||||
} else {
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
if (true == negate) {
|
||||
HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in);
|
||||
|
||||
vec_out = hvx_vec_exp_fp32(neg_vec_in);
|
||||
} else {
|
||||
vec_out = hvx_vec_exp_fp32(in);
|
||||
}
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
// assert((0 == unaligned_addr) || (0 == num_elems_whole));
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * p_vec_in = (HVX_Vector *) src;
|
||||
HVX_Vector * p_vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
*p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
HVX_Vector out = hvx_vec_inverse_fp32(in);
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#if 0
|
||||
// Reference algo used in hvx-utils
|
||||
static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems)
|
||||
{
|
||||
const float c1 = 0.03138777;
|
||||
const float c2 = 0.276281267;
|
||||
const float c_log2f = 1.442695022;
|
||||
|
||||
int32_t store_ints[32];
|
||||
float store_floats[3][32];
|
||||
|
||||
for (int i = 0; i < num_elems; i++)
|
||||
{
|
||||
float v = src0[i];
|
||||
|
||||
v *= c_log2f*0.5;
|
||||
int intPart = (int)v;
|
||||
float x = (v - intPart);
|
||||
float xx = x * x;
|
||||
float v1 = c_log2f + c2 * xx;
|
||||
float v2 = x + xx * c1 * x;
|
||||
float v3 = (v2 + v1);
|
||||
*((int*)&v3) += intPart << 24;
|
||||
float v4 = v2 - v1;
|
||||
float v5 = v3 - v4;
|
||||
float res = v3 / v5;
|
||||
|
||||
dst[i] = res;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
@ -0,0 +1,947 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
#define htp_binary_ops_preamble \
|
||||
int step_of_4 = num_elems >> 7; \
|
||||
int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6; \
|
||||
int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5; \
|
||||
int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \
|
||||
\
|
||||
const uint8_t * restrict src0_curr = src0; \
|
||||
const uint8_t * restrict src1_curr = src1; \
|
||||
uint8_t * restrict dst_curr = dst;
|
||||
|
||||
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_mul_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
htp_binary_ops_preamble;
|
||||
|
||||
for (int i = 0; i < step_of_4; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||
|
||||
src0_curr += 4 * VLEN;
|
||||
|
||||
HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b);
|
||||
|
||||
src1_curr += 4 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||
|
||||
dst_curr += 4 * VLEN;
|
||||
}
|
||||
|
||||
for (int i = 0; i < step_of_2; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||
|
||||
src1_curr += 2 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
dst_curr += 2 * VLEN;
|
||||
}
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||
|
||||
src0_curr += VLEN;
|
||||
|
||||
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||
|
||||
src1_curr += VLEN;
|
||||
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||
|
||||
dst_curr += VLEN;
|
||||
}
|
||||
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
const uint8_t * restrict src2,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
const uint8_t * restrict src0_curr = src0;
|
||||
const uint8_t * restrict src1_curr = src1;
|
||||
const uint8_t * restrict src2_curr = src2;
|
||||
uint8_t * restrict dst_curr = dst;
|
||||
|
||||
int step_of_2 = num_elems >> 6;
|
||||
int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5;
|
||||
int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32;
|
||||
|
||||
for (int i = 0; i < step_of_2; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
HVX_Vector v1c = *(HVX_Vector *) src2_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b);
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN);
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
|
||||
HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c);
|
||||
|
||||
src1_curr += 2 * VLEN;
|
||||
src2_curr += 2 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
dst_curr += 2 * VLEN;
|
||||
}
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||
src0_curr += VLEN;
|
||||
|
||||
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||
src1_curr += VLEN;
|
||||
|
||||
HVX_Vector vc = *(HVX_Vector *) src2_curr;
|
||||
src2_curr += VLEN;
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2);
|
||||
dst_curr += VLEN;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr);
|
||||
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_add_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_add_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
htp_binary_ops_preamble;
|
||||
|
||||
for (int i = 0; i < step_of_4; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||
|
||||
src0_curr += 4 * VLEN;
|
||||
|
||||
HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b);
|
||||
|
||||
src1_curr += 4 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||
|
||||
dst_curr += 4 * VLEN;
|
||||
}
|
||||
for (int i = 0; i < step_of_2; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b);
|
||||
|
||||
src1_curr += 2 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
dst_curr += 2 * VLEN;
|
||||
}
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||
|
||||
src0_curr += VLEN;
|
||||
|
||||
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||
|
||||
src1_curr += VLEN;
|
||||
|
||||
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||
|
||||
dst_curr += VLEN;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_sub_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0;
|
||||
HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32);
|
||||
HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * src0f = (const float *) src0 + num_elems_whole;
|
||||
const float * src1f = (const float *) src1 + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in1 = *(HVX_UVector *) src0f;
|
||||
HVX_Vector in2 = *(HVX_UVector *) src1f;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
htp_binary_ops_preamble;
|
||||
|
||||
for (int i = 0; i < step_of_4; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN);
|
||||
|
||||
HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN);
|
||||
|
||||
src0_curr += 4 * VLEN;
|
||||
|
||||
HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN);
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3);
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b);
|
||||
|
||||
src1_curr += 4 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4);
|
||||
|
||||
dst_curr += 4 * VLEN;
|
||||
}
|
||||
for (int i = 0; i < step_of_2; i++) {
|
||||
HVX_Vector v1a = *(HVX_Vector *) src0_curr;
|
||||
|
||||
HVX_Vector v1b = *(HVX_Vector *) src1_curr;
|
||||
|
||||
HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b);
|
||||
|
||||
HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1);
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b);
|
||||
|
||||
src1_curr += 2 * VLEN;
|
||||
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
dst_curr += 2 * VLEN;
|
||||
}
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector va = *(HVX_Vector *) src0_curr;
|
||||
|
||||
src0_curr += VLEN;
|
||||
|
||||
HVX_Vector vb = *(HVX_Vector *) src1_curr;
|
||||
|
||||
src1_curr += VLEN;
|
||||
|
||||
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v);
|
||||
|
||||
dst_curr += VLEN;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr);
|
||||
hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector val_vec = hvx_vec_splat_fp32(val);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||
FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||
|
||||
HVX_Vector * restrict vec_in1 = (HVX_Vector *) src;
|
||||
|
||||
HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1);
|
||||
sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v);
|
||||
vec_in1++;
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
|
||||
HVX_Vector vec_left = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left);
|
||||
HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32);
|
||||
|
||||
sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc);
|
||||
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
|
||||
float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||
FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * vec_in = (HVX_Vector *) src;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
// sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++);
|
||||
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
|
||||
HVX_Vector vec_left = *(HVX_UVector *) srcf;
|
||||
HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32);
|
||||
// sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp);
|
||||
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec);
|
||||
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||
}
|
||||
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * vec_in1 = (HVX_Vector *) src;
|
||||
HVX_Vector * vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
|
||||
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
|
||||
int left_over = num_elems & (VLEN_FP32 - 1);
|
||||
int num_elems_whole = num_elems - left_over;
|
||||
|
||||
int unaligned_addr = 0;
|
||||
int unaligned_loop = 0;
|
||||
if (0 == htp_is_aligned((void *) src, VLEN)) {
|
||||
FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
unaligned_addr = 1;
|
||||
}
|
||||
|
||||
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
|
||||
unaligned_loop = 1;
|
||||
FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
HVX_Vector vec_max = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||
HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||
|
||||
if (0 == unaligned_loop) {
|
||||
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
|
||||
|
||||
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in);
|
||||
}
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32);
|
||||
vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, temp);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max);
|
||||
return hvx_vec_get_fp32(v);
|
||||
}
|
||||
|
||||
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||
|
||||
const float * src_f = (const float *) src;
|
||||
|
||||
HVX_Vector vec_min = Q6_V_vsplat_R(val);
|
||||
|
||||
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(vec_min);
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in);
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min));
|
||||
}
|
||||
}
|
||||
|
||||
void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
||||
const float limit_left,
|
||||
const float limit_right,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems) {
|
||||
size_t left_over = num_elems & (VLEN_FP32 - 1);
|
||||
size_t num_elems_whole = num_elems - left_over;
|
||||
|
||||
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
|
||||
FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n");
|
||||
}
|
||||
|
||||
assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole));
|
||||
|
||||
HVX_Vector * restrict vec_in = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector range_left = hvx_vec_splat_fp32(limit_left);
|
||||
HVX_Vector range_right = hvx_vec_splat_fp32(limit_right);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
|
||||
HVX_Vector in_vec = *vec_in++;
|
||||
HVX_Vector temp_v = in_vec;
|
||||
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right);
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec);
|
||||
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
|
||||
|
||||
*vec_out++ = Q6_Vsf_equals_Vqf32(in_vec);
|
||||
}
|
||||
|
||||
if (left_over > 0) {
|
||||
const float * srcf = (const float *) src + num_elems_whole;
|
||||
float * dstf = (float *) dst + num_elems_whole;
|
||||
|
||||
HVX_Vector in = *(HVX_UVector *) srcf;
|
||||
|
||||
HVX_Vector temp_v = in;
|
||||
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right);
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in);
|
||||
|
||||
in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v);
|
||||
in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v);
|
||||
|
||||
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,998 @@
|
|||
#ifndef HVX_UTILS_H
|
||||
#define HVX_UTILS_H
|
||||
|
||||
#include "ops-utils.h"
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define SIZEOF_FP32 (4)
|
||||
#define SIZEOF_FP16 (2)
|
||||
#define VLEN (128)
|
||||
#define VLEN_FP32 (VLEN / SIZEOF_FP32)
|
||||
#define VLEN_FP16 (VLEN / SIZEOF_FP16)
|
||||
|
||||
static inline HVX_Vector hvx_vec_splat_fp32(float i) {
|
||||
union {
|
||||
float f;
|
||||
int32_t i;
|
||||
} fp32 = { .f = i };
|
||||
|
||||
return Q6_V_vsplat_R(fp32.i);
|
||||
}
|
||||
|
||||
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
|
||||
// Rotate as needed.
|
||||
v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
|
||||
|
||||
uint32_t left_off = (size_t) addr & 127;
|
||||
uint32_t right_off = left_off + n;
|
||||
|
||||
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr);
|
||||
HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off);
|
||||
|
||||
if (right_off > 128) {
|
||||
Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v);
|
||||
// all 1's
|
||||
qr = Q6_Q_vcmp_eq_VbVb(v, v);
|
||||
}
|
||||
|
||||
ql_not = Q6_Q_or_QQn(ql_not, qr);
|
||||
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v);
|
||||
}
|
||||
|
||||
static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) {
|
||||
assert((unsigned long) ptr % 128 == 0);
|
||||
|
||||
HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr);
|
||||
HVX_VectorPred qr = Q6_Q_vsetq2_R(n);
|
||||
ql_not = Q6_Q_or_QQn(ql_not, qr);
|
||||
Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) {
|
||||
// vdelta control to replicate first 4 bytes across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
};
|
||||
|
||||
HVX_Vector ctrl = *(HVX_Vector *) repl;
|
||||
return Q6_V_vdelta_VV(v, ctrl);
|
||||
}
|
||||
|
||||
// copy n fp16 elements : source and destination are aligned to HVX Vector (128)
|
||||
static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy n fp16 elements : source is aligned, destination is potentially unaligned
|
||||
static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy n fp16 elements : source is aligned, destination is potentially unaligned
|
||||
static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 64;
|
||||
uint32_t nloe = n % 64;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source and destination are aligned to HVX Vector (128)
|
||||
static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is aligned, destination is unaligned
|
||||
static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
|
||||
HVX_Vector * restrict vsrc = (HVX_Vector *) src;
|
||||
|
||||
assert((unsigned long) src % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||
}
|
||||
}
|
||||
|
||||
// copy n fp32 elements : source is unaligned, destination is aligned
|
||||
static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
vdst[i] = v;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector v = vsrc[i];
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
|
||||
}
|
||||
}
|
||||
|
||||
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
|
||||
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector velem = hvx_vec_splat_fp32(elem);
|
||||
|
||||
assert((unsigned long) dst % 128 == 0);
|
||||
|
||||
uint32_t nvec = n / 32;
|
||||
uint32_t nloe = n % 32;
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; i++) {
|
||||
vdst[i] = velem;
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem);
|
||||
}
|
||||
}
|
||||
|
||||
static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
__fp16 d[64];
|
||||
} u = { .v = v };
|
||||
|
||||
const uint32_t n0 = n / 16;
|
||||
const uint32_t n1 = n % 16;
|
||||
int i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp16_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp16_line(pref, u.d + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) {
|
||||
hvx_vec_dump_fp16_n(pref, v, 64);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
float d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
const uint32_t n0 = n / 16;
|
||||
const uint32_t n1 = n % 16;
|
||||
int i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp32_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp32_line(pref, u.d + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
float d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1],
|
||||
u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) {
|
||||
hvx_vec_dump_fp32_n(pref, v, 32);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int32(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int32_t d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 32 / 16; i++) {
|
||||
htp_dump_int32_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int32_t d[32];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12],
|
||||
u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60],
|
||||
u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]);
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_int8(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 128 / 16; i++) {
|
||||
htp_dump_int8_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) {
|
||||
union {
|
||||
HVX_Vector v;
|
||||
uint8_t d[128];
|
||||
} u = { .v = v };
|
||||
|
||||
for (int i = 0; i < 128 / 16; i++) {
|
||||
htp_dump_uint8_line(pref, u.d + (16 * i), 16);
|
||||
}
|
||||
}
|
||||
|
||||
static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) {
|
||||
typedef union {
|
||||
HVX_Vector v;
|
||||
int8_t d[128];
|
||||
} U;
|
||||
|
||||
U u0 = { .v = v0 };
|
||||
U u1 = { .v = v1 };
|
||||
|
||||
for (int i = 0; i < n; i++) {
|
||||
if (u0.d[i] != u1.d[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static inline float hvx_vec_get_fp32(HVX_Vector v) {
|
||||
float __attribute__((aligned(128))) x;
|
||||
hvx_vec_store_a(&x, 4, v);
|
||||
return x;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // int32
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||
sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) {
|
||||
return hvx_vec_int32_reduce_sum_n(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right
|
||||
sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) {
|
||||
return hvx_vec_qf32_reduce_sum_n(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) {
|
||||
unsigned int total = n * 4; // total vec nbytes
|
||||
unsigned int width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector sum = in, sum_t;
|
||||
while (width < total) {
|
||||
sum_t = Q6_V_vror_VR(sum, width); // rotate right
|
||||
sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum
|
||||
width = width << 1;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) {
|
||||
return hvx_vec_fp32_reduce_sum_n(in, 32);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 2; // fp16 nbytes
|
||||
|
||||
HVX_Vector _max = in, _max_t;
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 2; // fp32 nbytes
|
||||
|
||||
HVX_Vector _max_t;
|
||||
|
||||
_max = Q6_Vhf_vmax_VhfVhf(in, _max);
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector _max = in, _max_t;
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) {
|
||||
unsigned total = 128; // total vec nbytes
|
||||
unsigned width = 4; // fp32 nbytes
|
||||
|
||||
HVX_Vector _max_t;
|
||||
|
||||
_max = Q6_Vsf_vmax_VsfVsf(in, _max);
|
||||
while (width < total) {
|
||||
_max_t = Q6_V_vror_VR(_max, width); // rotate right
|
||||
_max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max
|
||||
width = width << 1;
|
||||
}
|
||||
|
||||
return _max;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) {
|
||||
// abs by clearing the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
|
||||
return Q6_V_vand_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) {
|
||||
// neg by setting the fp16 sign bit
|
||||
HVX_Vector mask = Q6_Vh_vsplat_R(0x8000);
|
||||
return Q6_V_vor_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
|
||||
// abs by clearing the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff);
|
||||
return Q6_V_vand_VV(v, mask);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
|
||||
#if __HTP_ARCH__ > 75
|
||||
return Q6_Vsf_vfneg_Vsf(v);
|
||||
#else
|
||||
// neg by setting the fp32 sign bit
|
||||
HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
|
||||
return Q6_V_vor_VV(v, mask);
|
||||
#endif // __HTP_ARCH__ > 75
|
||||
}
|
||||
|
||||
// ====================================================
|
||||
// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5
|
||||
// Order:3; continuity: True; Ends forced: True
|
||||
// Mode: unsigned; Result fractional bits: 14
|
||||
// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05
|
||||
// 32769 -32706 31252 -10589
|
||||
// 32590 -30635 22793 -4493
|
||||
// 32066 -27505 16481 -2348
|
||||
// 31205 -24054 11849 -1306
|
||||
|
||||
static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) {
|
||||
// input is 0..0xffff representing 0.0 .. 1.0
|
||||
HVX_Vector p;
|
||||
p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull);
|
||||
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull);
|
||||
p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull);
|
||||
p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull);
|
||||
return p; // signed result, 14 fractional bits
|
||||
}
|
||||
|
||||
// Find reciprocal of fp16.
|
||||
// (1) first, convert to fp32, multiplying by 1.0; this is done to
|
||||
// handle denormals. Ignoring sign and zero, result should be at
|
||||
// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000)
|
||||
// (exponent in range [103,143])
|
||||
// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly
|
||||
// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32
|
||||
// (4) convert that to fp16
|
||||
// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace
|
||||
// the result with the max value.
|
||||
static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) {
|
||||
HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
|
||||
HVX_Vector avals = Q6_V_vand_VV(vals, em_mask);
|
||||
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals);
|
||||
// is too small to 1/x ? for 'standard' fp16, this would be 0x101
|
||||
HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals);
|
||||
|
||||
HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0
|
||||
HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32));
|
||||
HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32));
|
||||
|
||||
// bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector
|
||||
HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9));
|
||||
// likewise extract the upper 16 from each, containing the exponents in range 103..142
|
||||
HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0);
|
||||
//Get exponent in IEEE 32-bit representation
|
||||
exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7);
|
||||
|
||||
// so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane
|
||||
// We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0)
|
||||
// Use poly to transform to 1/x, with 14 fractional bits
|
||||
//
|
||||
HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16);
|
||||
|
||||
HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros
|
||||
|
||||
// Get mantissa for 16-bit represenation
|
||||
HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF));
|
||||
|
||||
//Compute Reciprocal Exponent
|
||||
HVX_Vector exp_recip =
|
||||
Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1)));
|
||||
//Convert it for 16-bit representation
|
||||
exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15));
|
||||
exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10);
|
||||
|
||||
//Merge exponent and mantissa for reciprocal
|
||||
HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip);
|
||||
// map 'small' inputs to standard largest value 0x7bff
|
||||
recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip);
|
||||
// add sign back
|
||||
recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000);
|
||||
return recip;
|
||||
}
|
||||
|
||||
#define IEEE_VSF_EXPLEN (8)
|
||||
#define IEEE_VSF_EXPBIAS (127)
|
||||
#define IEEE_VSF_EXPMASK (0xFF)
|
||||
#define IEEE_VSF_MANTLEN (23)
|
||||
#define IEEE_VSF_MANTMASK (0x7FFFFF)
|
||||
#define IEEE_VSF_MIMPMASK (0x800000)
|
||||
|
||||
static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) {
|
||||
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||
|
||||
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||
|
||||
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||
expval_v &= IEEE_VSF_EXPMASK;
|
||||
expval_v -= IEEE_VSF_EXPBIAS;
|
||||
|
||||
// negative exp == fractional value
|
||||
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||
|
||||
HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift
|
||||
|
||||
HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa
|
||||
HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0
|
||||
|
||||
vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer
|
||||
vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0
|
||||
|
||||
HVX_Vector neg_vout = -vout;
|
||||
|
||||
vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives
|
||||
|
||||
return (vout);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) {
|
||||
HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK);
|
||||
HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK);
|
||||
HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN);
|
||||
HVX_Vector const_zero_v = Q6_V_vzero();
|
||||
HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf
|
||||
|
||||
HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec);
|
||||
|
||||
HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN;
|
||||
expval_v &= IEEE_VSF_EXPMASK;
|
||||
expval_v -= IEEE_VSF_EXPBIAS;
|
||||
|
||||
HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v);
|
||||
HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v);
|
||||
HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v);
|
||||
HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec);
|
||||
|
||||
// if expval < 0 (q_negexp) // <0, floor is 0
|
||||
// if vin > 0
|
||||
// floor = 0
|
||||
// if vin < 0
|
||||
// floor = -1
|
||||
// if expval < mant_len (q_expltmn) // >0, but fraction may exist
|
||||
// get sign (q_negative)
|
||||
// mask >> expval // fraction bits to mask off
|
||||
// vout = ~(mask) // apply mask to remove fraction
|
||||
// if (qneg) // negative floor is one less (more, sign bit for neg)
|
||||
// vout += ((impl_mask) >> expval)
|
||||
// if (mask && vin)
|
||||
// vout = vin
|
||||
// else // already an integer
|
||||
// ; // no change
|
||||
|
||||
// compute floor
|
||||
mask_mant_v >>= expval_v;
|
||||
HVX_Vector neg_addin_v = mask_impl_v >> expval_v;
|
||||
HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v);
|
||||
HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec);
|
||||
|
||||
HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set
|
||||
HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v);
|
||||
|
||||
HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear
|
||||
HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits
|
||||
|
||||
vout = in_vec;
|
||||
vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant
|
||||
vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values
|
||||
vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0
|
||||
vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1
|
||||
|
||||
return vout;
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
// Ideally should just be Q6_Vh_equals_Vhf(vin)
|
||||
// but that instruction does not do proper rounding.
|
||||
|
||||
// convert to qf32, multiplying by 1.0 in the process.
|
||||
HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00));
|
||||
|
||||
// 'in-range' values are +/32752.
|
||||
// add 192K to it, convert to sf
|
||||
HVX_Vector v192K = Q6_V_vsplat_R(0x48400000);
|
||||
HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K));
|
||||
HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K));
|
||||
|
||||
// for in-range cases, result is {163858... 229360} so the exponent is always 144.
|
||||
// if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer.
|
||||
// Start by <<10 to get the final 'sign' bit in bit 15...
|
||||
vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10);
|
||||
vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10);
|
||||
|
||||
// now round down to 16
|
||||
return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) {
|
||||
HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3);
|
||||
HVX_Vector two_sf = hvx_vec_splat_fp32(2.0);
|
||||
|
||||
// First approximation
|
||||
HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf);
|
||||
|
||||
HVX_Vector r_qf;
|
||||
|
||||
// Refine
|
||||
r_qf = Q6_Vqf32_vmpy_VsfVsf(
|
||||
i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf)))));
|
||||
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||
r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32(
|
||||
r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf))));
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(r_qf);
|
||||
}
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267
|
||||
#define FAST_SIGMOID_C3 (0x3f000000) // 0.5
|
||||
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) {
|
||||
v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||
v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3));
|
||||
|
||||
HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v));
|
||||
HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int));
|
||||
HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x);
|
||||
|
||||
HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2));
|
||||
v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F));
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1));
|
||||
v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx);
|
||||
v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x);
|
||||
|
||||
HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1));
|
||||
HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1);
|
||||
v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24);
|
||||
v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent);
|
||||
v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24);
|
||||
|
||||
HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1));
|
||||
HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4));
|
||||
|
||||
HVX_Vector res = hvx_vec_inverse_fp32(v5);
|
||||
res = Q6_Vqf32_vmpy_VsfVsf(v3, res);
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(res);
|
||||
}
|
||||
|
||||
#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!)
|
||||
#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!)
|
||||
#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!)
|
||||
#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!)
|
||||
#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!)
|
||||
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x41a00000) // 20.0
|
||||
#define EXP_RANGE_L (0xc1a00000) // -20.0
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) {
|
||||
HVX_Vector z_qf32_v;
|
||||
HVX_Vector x_v;
|
||||
HVX_Vector x_qf32_v;
|
||||
HVX_Vector y_v;
|
||||
HVX_Vector k_v;
|
||||
HVX_Vector f_v;
|
||||
HVX_Vector epsilon_v;
|
||||
HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E);
|
||||
HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2);
|
||||
HVX_Vector E_const;
|
||||
HVX_Vector zero_v = Q6_V_vzero();
|
||||
|
||||
// exp(x) is approximated as follows:
|
||||
// f = floor(x/ln(2)) = floor(x*log2(e))
|
||||
// epsilon = x - f*ln(2)
|
||||
// exp(x) = exp(epsilon+f*ln(2))
|
||||
// = exp(epsilon)*exp(f*ln(2))
|
||||
// = exp(epsilon)*2^f
|
||||
//
|
||||
// Since epsilon is close to zero, it can be approximated with its Taylor series:
|
||||
// exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+...
|
||||
// Preserving the first eight elements, we get:
|
||||
// exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7
|
||||
// = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2
|
||||
|
||||
HVX_Vector temp_v = in_vec;
|
||||
|
||||
// Clamp inputs to (-20.0, 20.0)
|
||||
HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R));
|
||||
HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec);
|
||||
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v);
|
||||
in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v);
|
||||
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec);
|
||||
epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v);
|
||||
|
||||
// f_v is the floating point result and k_v is the integer result
|
||||
f_v = hvx_vec_floor_fp32(epsilon_v);
|
||||
k_v = hvx_vec_truncate_fp32(f_v);
|
||||
|
||||
x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v);
|
||||
|
||||
// x = x - f_v * logn2;
|
||||
epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2);
|
||||
x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v);
|
||||
// normalize before every QFloat's vmpy
|
||||
x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v);
|
||||
|
||||
// z = x * x;
|
||||
z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v);
|
||||
z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v);
|
||||
|
||||
x_v = Q6_Vsf_equals_Vqf32(x_qf32_v);
|
||||
|
||||
// y = E4 + E5 * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_5);
|
||||
y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v);
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_4);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E3 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_3);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E2 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_2);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E1 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_1);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = E0 + y * x;
|
||||
E_const = Q6_V_vsplat_R(EXP_COEFF_0);
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = x + y * z;
|
||||
y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v);
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v);
|
||||
|
||||
// y = y + 1.0;
|
||||
y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE));
|
||||
|
||||
// insert exponents
|
||||
// y = ldexpf(y, k);
|
||||
// y_v += k_v; // qf32
|
||||
// modify exponent
|
||||
|
||||
y_v = Q6_Vsf_equals_Vqf32(y_v);
|
||||
|
||||
// add k_v to the exponent of y_v
|
||||
HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1);
|
||||
|
||||
y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1);
|
||||
y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent);
|
||||
|
||||
// exponent cannot be negative; if overflow is detected, result is set to zero
|
||||
HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent);
|
||||
|
||||
y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN);
|
||||
|
||||
y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v);
|
||||
|
||||
return y_v;
|
||||
}
|
||||
|
||||
#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation
|
||||
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
|
||||
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
|
||||
|
||||
static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) {
|
||||
//Algorithm :
|
||||
// x2 = input*0.5
|
||||
// y = * (long *) &input
|
||||
// y = 0x5f3759df - (y>>2)
|
||||
// y = y*(threehalfs - x2*y*y)
|
||||
|
||||
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
|
||||
HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF);
|
||||
HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES);
|
||||
|
||||
HVX_Vector x2, y, ypower2, temp;
|
||||
|
||||
x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf);
|
||||
x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero());
|
||||
|
||||
y = Q6_Vw_vasr_VwR(in_vec, 1);
|
||||
y = Q6_Vw_vsub_VwVw(rsqrtconst, y);
|
||||
|
||||
// 1st iteration
|
||||
ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp));
|
||||
|
||||
// 2nd iteration
|
||||
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||
|
||||
// 3rd iteration
|
||||
y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero());
|
||||
ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y);
|
||||
ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero());
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2);
|
||||
temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp));
|
||||
temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp);
|
||||
|
||||
return Q6_Vsf_equals_Vqf32(temp);
|
||||
}
|
||||
|
||||
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
|
||||
int step_of_1 = num_elems >> 5;
|
||||
int remaining = num_elems - step_of_1 * VLEN_FP32;
|
||||
|
||||
assert(remaining == 0);
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
|
||||
void hvx_mul_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_mul_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_mul_mul_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
const uint8_t * restrict src2,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_add_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_add_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_sub_f32(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_sub_f32_opt(const uint8_t * restrict src0,
|
||||
const uint8_t * restrict src1,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
|
||||
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);
|
||||
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems);
|
||||
float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems);
|
||||
void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
|
||||
void hvx_clamp_scalar_f32(const uint8_t * restrict src,
|
||||
const float limit_left,
|
||||
const float limit_right,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems);
|
||||
|
||||
#endif /* HVX_UTILS_H */
|
||||
|
|
@ -0,0 +1,945 @@
|
|||
#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
|
||||
#define FARF_ERROR 1
|
||||
#define FARF_HIGH 1
|
||||
#define FARF_MEDIUM 0
|
||||
#define FARF_LOW 0
|
||||
#include <AEEStdErr.h>
|
||||
#include <dspqueue.h>
|
||||
#include <HAP_compute_res.h>
|
||||
#include <HAP_etm_config.h>
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_power.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <qurt.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <remote.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "ops-utils.h"
|
||||
#include "worker-pool.h"
|
||||
|
||||
AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) {
|
||||
struct htp_context * ctx;
|
||||
int err = 0;
|
||||
|
||||
ctx = calloc(1, sizeof(*ctx));
|
||||
if (ctx == NULL) {
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
// Use the context structure as a handle
|
||||
*handle = (remote_handle64) ctx;
|
||||
|
||||
// Enable FARF logs
|
||||
HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0);
|
||||
|
||||
// Set client class
|
||||
{
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_apptype;
|
||||
request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS;
|
||||
|
||||
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(request));
|
||||
|
||||
request.type = HAP_power_set_DCVS_v3;
|
||||
request.dcvs_v3.set_dcvs_enable = TRUE;
|
||||
request.dcvs_v3.dcvs_enable = TRUE;
|
||||
request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE;
|
||||
request.dcvs_v3.set_bus_params = TRUE;
|
||||
request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.bus_params.target_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.set_core_params = TRUE;
|
||||
request.dcvs_v3.core_params.min_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.core_params.max_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX;
|
||||
request.dcvs_v3.set_sleep_disable = TRUE;
|
||||
request.dcvs_v3.sleep_disable = TRUE;
|
||||
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||
return err;
|
||||
}
|
||||
|
||||
memset(&request, 0, sizeof(request));
|
||||
request.type = HAP_power_set_HVX;
|
||||
request.hvx.power_up = TRUE;
|
||||
if ((err = HAP_power_set((void *) ctx, &request)) != 0) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Power on HMX
|
||||
HAP_power_request_t request;
|
||||
memset(&request, 0, sizeof(HAP_power_request_t));
|
||||
request.type = HAP_power_set_HMX;
|
||||
request.hmx.power_up = TRUE;
|
||||
FARF(ALWAYS, "Powering HMX on\n");
|
||||
err = HAP_power_set((void *) &ctx, &request);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Error powering on HMX.");
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_close(remote_handle64 handle) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
if (ctx->queue) {
|
||||
FARF(ERROR, "Closing handle with queue still open");
|
||||
return AEE_EITEMBUSY;
|
||||
}
|
||||
|
||||
free(ctx);
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_enable_etm(remote_handle64 handle) {
|
||||
int err = HAP_user_etm_enable();
|
||||
if (err) {
|
||||
if (err == AEE_EVERSIONNOTSUPPORT) {
|
||||
FARF(ERROR, "API HAP_user_etm_enable is not supported\n");
|
||||
} else {
|
||||
FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err);
|
||||
}
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_disable_etm(remote_handle64 handle) {
|
||||
int err = HAP_user_etm_disable();
|
||||
if (err) {
|
||||
if (err == AEE_EVERSIONNOTSUPPORT) {
|
||||
FARF(ERROR, "API HAP_user_etm_disable is not supported\n");
|
||||
} else {
|
||||
FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err);
|
||||
}
|
||||
}
|
||||
return err;
|
||||
}
|
||||
|
||||
static int vtcm_acquire(struct htp_context * ctx) {
|
||||
if (!ctx->vtcm_valid) {
|
||||
// Temporarily bump thread priority to make sure it's higher than other sessions.
|
||||
// This way the resource manager will notify the other thread to release VTCM.
|
||||
// Note that we need to reaquire VTCM at normal priority for this to work next time.
|
||||
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10);
|
||||
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||
qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio);
|
||||
|
||||
HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000);
|
||||
ctx->vtcm_valid = true;
|
||||
}
|
||||
|
||||
ctx->vtcm_inuse = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int vtcm_release(struct htp_context * ctx) {
|
||||
ctx->vtcm_inuse = false;
|
||||
|
||||
if (ctx->vtcm_valid && ctx->vtcm_needs_release) {
|
||||
ctx->vtcm_valid = false;
|
||||
ctx->vtcm_needs_release = false;
|
||||
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int vtcm_release_callback(unsigned int rctx, void * state) {
|
||||
struct htp_context * ctx = (struct htp_context *) state;
|
||||
|
||||
if (!ctx || ctx->vtcm_rctx != rctx) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
// If VTCM is not inuse (not processing Ops) release it right here
|
||||
// otherwise we'll release it once we're done with the current Op.
|
||||
|
||||
if (ctx->vtcm_inuse) {
|
||||
ctx->vtcm_needs_release = false;
|
||||
return 0;
|
||||
}
|
||||
|
||||
ctx->vtcm_valid = false;
|
||||
HAP_compute_res_release_cached(ctx->vtcm_rctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int vtcm_alloc(struct htp_context * ctx) {
|
||||
unsigned int vtcm_size = 8 * 1024 * 1024; // 8MB default
|
||||
HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL);
|
||||
|
||||
compute_res_attr_t attr;
|
||||
HAP_compute_res_attr_init(&attr);
|
||||
HAP_compute_res_attr_set_serialize(&attr, 0);
|
||||
HAP_compute_res_attr_set_cache_mode(&attr, 1);
|
||||
HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size);
|
||||
HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx);
|
||||
HAP_compute_res_attr_set_hmx_param(&attr, 1);
|
||||
|
||||
// Allocate VTCM for scratch pads
|
||||
uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */);
|
||||
if (!rctx) {
|
||||
FARF(ERROR, "failed to allocate %zu bytes VTCM\n", ctx->vtcm_size);
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
void * vtcm_ptr;
|
||||
if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) {
|
||||
HAP_compute_res_release(rctx);
|
||||
FARF(ERROR, "failed to allocate %zu bytes VTCM (new)\n", ctx->vtcm_size);
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
ctx->vtcm_base = (uint8_t *) vtcm_ptr;
|
||||
ctx->vtcm_size = vtcm_size;
|
||||
ctx->vtcm_rctx = rctx;
|
||||
ctx->vtcm_valid = false;
|
||||
ctx->vtcm_inuse = false;
|
||||
ctx->vtcm_needs_release = false;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void vtcm_free(struct htp_context * ctx) {
|
||||
if (ctx->vtcm_rctx) {
|
||||
HAP_compute_res_release(ctx->vtcm_rctx);
|
||||
ctx->vtcm_base = 0;
|
||||
ctx->vtcm_rctx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
if (ctx->queue) {
|
||||
FARF(ERROR, "Queue already open");
|
||||
return AEE_EITEMBUSY;
|
||||
}
|
||||
|
||||
// Import queue created on the CPU
|
||||
int err = dspqueue_import(dsp_queue_id, // Queue ID from dspqueue_export
|
||||
htp_packet_callback, // Packet callback
|
||||
htp_error_callback, // Error callback; no errors expected on the DSP
|
||||
(void *) ctx, // Callback context
|
||||
&ctx->queue);
|
||||
|
||||
if (err) {
|
||||
FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err);
|
||||
return err;
|
||||
}
|
||||
|
||||
ctx->thread_id = qurt_thread_get_id();
|
||||
ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id);
|
||||
|
||||
// allocate VTCM
|
||||
err = vtcm_alloc(ctx);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Unable to allocate VTCM");
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
|
||||
|
||||
if (n_hvx == 0) {
|
||||
n_hvx = hw_nhvx;
|
||||
}
|
||||
if (n_hvx > hw_threads.max_hthreads) {
|
||||
n_hvx = hw_threads.max_hthreads;
|
||||
}
|
||||
if (n_hvx > HTP_MAX_NTHREADS) {
|
||||
n_hvx = HTP_MAX_NTHREADS;
|
||||
}
|
||||
|
||||
ctx->n_threads = n_hvx;
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
ctx->dma[i] = dma_queue_create(HTP_SPAD_SRC0_NROWS * 2);
|
||||
}
|
||||
|
||||
// init worker pool
|
||||
err = worker_pool_init(&ctx->worker_pool, n_hvx);
|
||||
if (err != AEE_SUCCESS) {
|
||||
FARF(ERROR, "Unable to create worker pool");
|
||||
return err;
|
||||
}
|
||||
|
||||
FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n",
|
||||
sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio);
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
if (!ctx) {
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
if (!ctx->queue) {
|
||||
FARF(ERROR, "Queue not open");
|
||||
return AEE_EBADSTATE;
|
||||
}
|
||||
|
||||
// Close queue. dspqueue_close() will also wait for callbacks to finish.
|
||||
int err = dspqueue_close(ctx->queue);
|
||||
ctx->queue = NULL;
|
||||
if (err != 0) {
|
||||
FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err);
|
||||
return err;
|
||||
}
|
||||
|
||||
if (ctx->worker_pool) {
|
||||
// Release worker pool
|
||||
worker_pool_release(&ctx->worker_pool);
|
||||
}
|
||||
|
||||
for (int i = 0; i < ctx->n_threads; i++) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
// No errors expected on the DSP.
|
||||
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
|
||||
}
|
||||
|
||||
struct profile_data {
|
||||
uint64_t usecs;
|
||||
uint64_t cycles;
|
||||
uint64_t pkts;
|
||||
};
|
||||
|
||||
static inline void profile_start(struct profile_data * d) {
|
||||
d->usecs = HAP_perf_get_qtimer_count();
|
||||
d->cycles = htp_get_cycles();
|
||||
d->pkts = htp_get_pktcnt();
|
||||
}
|
||||
|
||||
static inline void profile_stop(struct profile_data * d) {
|
||||
d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs);
|
||||
d->cycles = htp_get_cycles() - d->cycles;
|
||||
d->pkts = htp_get_pktcnt() - d->pkts;
|
||||
}
|
||||
|
||||
static int send_htp_rsp(struct htp_context * c,
|
||||
uint32_t op,
|
||||
uint32_t status,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs,
|
||||
struct profile_data * prof) {
|
||||
// Prep response struct
|
||||
struct htp_general_rsp rsp;
|
||||
rsp.op = op;
|
||||
rsp.status = status;
|
||||
rsp.prof_usecs = prof->usecs;
|
||||
rsp.prof_cycles = prof->cycles;
|
||||
rsp.prof_pkts = prof->pkts;
|
||||
|
||||
int err = dspqueue_write(c->queue,
|
||||
0, // Flags
|
||||
n_bufs,
|
||||
bufs, // Buffer references
|
||||
sizeof(rsp),
|
||||
(const uint8_t *) &rsp, // Message
|
||||
DSPQUEUE_TIMEOUT_NONE);
|
||||
|
||||
if (err != 0) {
|
||||
FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
static void proc_matmul_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_matmul(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
}
|
||||
|
||||
static void proc_matmul_id_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
size_t n_bufs) {
|
||||
// Prep response buffer structs (needed for error responses, etc)
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.src2 = req->src2;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_matmul_id(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
}
|
||||
|
||||
static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_binary(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 3, &prof);
|
||||
}
|
||||
|
||||
static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[2].fd = bufs[2].fd;
|
||||
rsp_bufs[2].ptr = bufs[2].ptr;
|
||||
rsp_bufs[2].offset = bufs[2].offset;
|
||||
rsp_bufs[2].size = bufs[2].size;
|
||||
rsp_bufs[2].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[3].fd = bufs[3].fd;
|
||||
rsp_bufs[3].ptr = bufs[3].ptr;
|
||||
rsp_bufs[3].offset = bufs[3].offset;
|
||||
rsp_bufs[3].size = bufs[3].size;
|
||||
rsp_bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
octx.src2 = req->src2;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_binary(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 4, &prof);
|
||||
}
|
||||
|
||||
static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_unary(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 2, &prof);
|
||||
}
|
||||
|
||||
static void proc_activations_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 1;
|
||||
if (3 == n_bufs) {
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx = 2;
|
||||
}
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
if (3 == n_bufs) {
|
||||
octx.src1 = req->src1;
|
||||
}
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
if (3 == n_bufs) {
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
} else {
|
||||
octx.dst.data = (uint32_t) bufs[1].ptr;
|
||||
}
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
if (octx.op == HTP_OP_SOFTMAX) {
|
||||
rsp_status = op_softmax(&octx);
|
||||
} else {
|
||||
rsp_status = op_activations(&octx);
|
||||
}
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
}
|
||||
|
||||
static void proc_rope_req(struct htp_context * ctx,
|
||||
struct htp_general_req * req,
|
||||
struct dspqueue_buffer * bufs,
|
||||
uint32_t n_bufs) {
|
||||
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
memset(rsp_bufs, 0, sizeof(rsp_bufs));
|
||||
|
||||
rsp_bufs[0].fd = bufs[0].fd;
|
||||
rsp_bufs[0].ptr = bufs[0].ptr;
|
||||
rsp_bufs[0].offset = bufs[0].offset;
|
||||
rsp_bufs[0].size = bufs[0].size;
|
||||
rsp_bufs[0].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
rsp_bufs[1].fd = bufs[1].fd;
|
||||
rsp_bufs[1].ptr = bufs[1].ptr;
|
||||
rsp_bufs[1].offset = bufs[1].offset;
|
||||
rsp_bufs[1].size = bufs[1].size;
|
||||
rsp_bufs[1].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
int write_idx = 2;
|
||||
if (4 == n_bufs) {
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = DSPQUEUE_BUFFER_FLAG_DEREF; // Release reference
|
||||
|
||||
write_idx++;
|
||||
}
|
||||
|
||||
// We had written to the output buffer, we'd also need to flush it
|
||||
rsp_bufs[write_idx].fd = bufs[write_idx].fd;
|
||||
rsp_bufs[write_idx].ptr = bufs[write_idx].ptr;
|
||||
rsp_bufs[write_idx].offset = bufs[write_idx].offset;
|
||||
rsp_bufs[write_idx].size = bufs[write_idx].size;
|
||||
rsp_bufs[write_idx].flags = (DSPQUEUE_BUFFER_FLAG_DEREF | // Release reference
|
||||
DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush NSP
|
||||
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
|
||||
|
||||
// Setup Op context
|
||||
struct htp_ops_context octx = { 0 };
|
||||
octx.ctx = ctx;
|
||||
octx.src0 = req->src0;
|
||||
octx.src1 = req->src1;
|
||||
if (4 == n_bufs) {
|
||||
octx.src2 = req->src2;
|
||||
}
|
||||
octx.dst = req->dst;
|
||||
octx.flags = req->flags;
|
||||
octx.op = req->op;
|
||||
|
||||
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
|
||||
|
||||
// Update data pointers
|
||||
octx.src0.data = (uint32_t) bufs[0].ptr;
|
||||
octx.src1.data = (uint32_t) bufs[1].ptr;
|
||||
if (4 == n_bufs) {
|
||||
octx.src2.data = (uint32_t) bufs[2].ptr;
|
||||
octx.dst.data = (uint32_t) bufs[3].ptr;
|
||||
} else {
|
||||
octx.dst.data = (uint32_t) bufs[2].ptr;
|
||||
}
|
||||
octx.n_threads = ctx->n_threads;
|
||||
|
||||
struct profile_data prof;
|
||||
profile_start(&prof);
|
||||
|
||||
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
|
||||
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
|
||||
rsp_status = op_rope(&octx);
|
||||
vtcm_release(ctx);
|
||||
}
|
||||
|
||||
profile_stop(&prof);
|
||||
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, n_bufs, &prof);
|
||||
}
|
||||
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
struct htp_context * ctx = (struct htp_context *) context;
|
||||
|
||||
// Repeatedly read packets from the queue until it's empty. We don't
|
||||
// necessarily get a separate callback for each packet, and new packets
|
||||
// may arrive while we're processing the previous one. This ensures we
|
||||
// keep the DSP busy as much as possible and avoid waiting for the CPU.
|
||||
|
||||
while (1) {
|
||||
struct htp_general_req req;
|
||||
uint32_t req_size;
|
||||
|
||||
struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS];
|
||||
uint32_t n_bufs;
|
||||
uint32_t flags;
|
||||
|
||||
// Read packet from queue
|
||||
int err = dspqueue_read_noblock(queue, &flags,
|
||||
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
|
||||
&n_bufs, // Number of buffer references
|
||||
bufs, // Buffer references
|
||||
sizeof(req), // Max message length
|
||||
&req_size, // Message length
|
||||
(uint8_t *) &req); // Message
|
||||
|
||||
if (err == AEE_EWOULDBLOCK) {
|
||||
// Consumed all packets available for now
|
||||
return;
|
||||
}
|
||||
|
||||
if (err != 0) {
|
||||
FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err);
|
||||
return;
|
||||
}
|
||||
|
||||
if (req_size != sizeof(req)) {
|
||||
FARF(ERROR, "Invalid request size");
|
||||
continue;
|
||||
}
|
||||
|
||||
if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) {
|
||||
// Host wants early notification
|
||||
dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0);
|
||||
}
|
||||
|
||||
// Process packet based on its message type
|
||||
switch (req.op) {
|
||||
case HTP_OP_MUL_MAT:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad matmul-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_matmul_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
if (n_bufs != 4) {
|
||||
FARF(ERROR, "Bad matmul-id-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_matmul_id_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
if (n_bufs != 3) {
|
||||
FARF(ERROR, "Bad binary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_binary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_RMS_NORM:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad unary-req buffer list");
|
||||
continue;
|
||||
}
|
||||
|
||||
proc_unary_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_UNARY_SILU:
|
||||
if (n_bufs != 2) {
|
||||
FARF(ERROR, "Bad act-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_activations_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
case HTP_OP_SOFTMAX:
|
||||
if ((n_bufs != 2) && (n_bufs != 3)) {
|
||||
FARF(ERROR, "Bad act-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_activations_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ADD_ID:
|
||||
if (n_bufs != 4) {
|
||||
FARF(ERROR, "Bad add-id-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_add_id_req(ctx, &req, bufs);
|
||||
break;
|
||||
|
||||
case HTP_OP_ROPE:
|
||||
if ((n_bufs != 3) && (n_bufs != 4)) {
|
||||
FARF(ERROR, "Bad rope-req buffer list");
|
||||
continue;
|
||||
}
|
||||
proc_rope_req(ctx, &req, bufs, n_bufs);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unknown Op %u", req.op);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,116 @@
|
|||
#ifndef OPS_UTILS_H
|
||||
#define OPS_UTILS_H
|
||||
|
||||
#include "htp-msg.h"
|
||||
|
||||
#ifndef MAX
|
||||
# define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
# define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
static inline uint64_t htp_get_cycles() {
|
||||
uint64_t cycles = 0;
|
||||
asm volatile(" %0 = c15:14\n" : "=r"(cycles));
|
||||
return cycles;
|
||||
}
|
||||
|
||||
static inline uint64_t htp_get_pktcnt() {
|
||||
uint64_t pktcnt;
|
||||
asm volatile(" %0 = c19:18\n" : "=r"(pktcnt));
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline int32_t htp_is_aligned(void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline uint32_t htp_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control));
|
||||
}
|
||||
|
||||
static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) {
|
||||
char str[1024], *p = str;
|
||||
p += sprintf(p, "%s: ", pref);
|
||||
for (int i = 0; i < 16; i++) {
|
||||
p += sprintf(p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) {
|
||||
char str[1024], *p = str;
|
||||
p += sprintf(p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += sprintf(p, "%d, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) {
|
||||
char str[1024], *p = str;
|
||||
p += sprintf(p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += sprintf(p, "%d, ", (int) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) {
|
||||
char str[1024], *p = str;
|
||||
p += sprintf(p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += sprintf(p, "%.6f, ", (float) x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) {
|
||||
char str[1024], *p = str;
|
||||
p += sprintf(p, "%s: ", pref);
|
||||
for (int i = 0; i < n; i++) {
|
||||
p += sprintf(p, "%.6f, ", x[i]);
|
||||
}
|
||||
FARF(HIGH, "%s\n", str);
|
||||
}
|
||||
|
||||
static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp32_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp32_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) {
|
||||
uint32_t n0 = n / 16;
|
||||
uint32_t n1 = n % 16;
|
||||
|
||||
uint32_t i = 0;
|
||||
for (; i < n0; i++) {
|
||||
htp_dump_fp16_line(pref, x + (16 * i), 16);
|
||||
}
|
||||
if (n1) {
|
||||
htp_dump_fp16_line(pref, x + (16 * i), n1);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* OPS_UTILS_H */
|
||||
|
|
@ -0,0 +1,418 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_rope_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct rope_th_ctx {
|
||||
int32_t n_dims;
|
||||
int32_t mode;
|
||||
int32_t n_ctx_orig;
|
||||
int32_t sections[4];
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
float ext_factor;
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
float theta_scale;
|
||||
float corr_dims[2];
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
};
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
|
||||
|
||||
return (1 - MIN(1, MAX(0, y)));
|
||||
}
|
||||
|
||||
static void rope_cache_init(const float theta_base,
|
||||
float freq_scale,
|
||||
const float * freq_factors,
|
||||
float * corr_dims,
|
||||
uint32_t ne0,
|
||||
float ext_factor,
|
||||
float mscale,
|
||||
float * cache,
|
||||
float theta_scale) {
|
||||
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
|
||||
float theta = theta_base;
|
||||
|
||||
for (uint32_t i0 = 0; i0 < ne0; i0 += 2) {
|
||||
const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f;
|
||||
|
||||
float theta_extrap = theta / ff;
|
||||
|
||||
// Get n-d rotational scaling corrected for extrapolation
|
||||
float theta_interp = freq_scale * theta_extrap;
|
||||
float theta2 = theta_interp;
|
||||
|
||||
if (ext_factor != 0.0f) {
|
||||
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
||||
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
}
|
||||
|
||||
cache[i0 + 0] = cosf(theta2) * mscale;
|
||||
cache[i0 + 1] = sinf(theta2) * mscale;
|
||||
|
||||
theta *= theta_scale;
|
||||
}
|
||||
}
|
||||
|
||||
#define M_PI 3.1415926535897932384626433
|
||||
|
||||
static void rope_corr_dims(int n_dims,
|
||||
int n_ctx_orig,
|
||||
float freq_base,
|
||||
float beta_fast,
|
||||
float beta_slow,
|
||||
float * dims) {
|
||||
float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base)));
|
||||
float end = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base)));
|
||||
dims[0] = MAX(0, start);
|
||||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
|
||||
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
|
||||
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
|
||||
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
|
||||
rope_ctx->mode = ((const int32_t *) op_params)[2];
|
||||
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
|
||||
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
|
||||
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
|
||||
|
||||
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
|
||||
rope_ctx->beta_slow, rope_ctx->corr_dims);
|
||||
|
||||
rope_ctx->octx = octx;
|
||||
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
|
||||
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[1];
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[1] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
//src += 2;
|
||||
//dst += 2;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
||||
HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin));
|
||||
HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin));
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += 2 * VLEN;
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||
const uint32_t ir0,
|
||||
const uint32_t ir1,
|
||||
int nth,
|
||||
int ith,
|
||||
int opt_path) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
int ir = 0;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
const int32_t p = pos[i2];
|
||||
|
||||
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||
|
||||
for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
||||
if (ir++ < ir0) {
|
||||
continue;
|
||||
}
|
||||
if (ir > ir1) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||
|
||||
const float * src_loc = src;
|
||||
float * dst_data_loc = dst_data;
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
} else {
|
||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||
const float cos_theta = wp0[i0 + 0];
|
||||
const float sin_theta = wp0[i0 + 1];
|
||||
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
|
||||
dst_data_loc[0] = src_loc[0];
|
||||
dst_data_loc[1] = src_loc[1];
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
|
||||
|
||||
rope_job_f32_per_thread(rope_ctx, n, i);
|
||||
}
|
||||
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct rope_th_ctx rope_ctx;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_ROPE:
|
||||
op_func = rope_job_dispatcher_f32;
|
||||
op_type = "rope-f32";
|
||||
|
||||
init_rope_ctx(&rope_ctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src2->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
|
||||
"dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_rope(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_rope_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,402 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_softmax_preamble3 \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
const uint32_t ne02 = src0->ne[2]; \
|
||||
const uint32_t ne03 = src0->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src0->nb[0]; \
|
||||
const uint32_t nb01 = src0->nb[1]; \
|
||||
const uint32_t nb02 = src0->nb[2]; \
|
||||
const uint32_t nb03 = src0->nb[3]; \
|
||||
\
|
||||
const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \
|
||||
const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \
|
||||
\
|
||||
const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \
|
||||
const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \
|
||||
const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \
|
||||
const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct softmax_th_ctx {
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
uint32_t n_head;
|
||||
uint32_t n_head_log2;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
};
|
||||
|
||||
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
|
||||
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
|
||||
|
||||
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
softmax_ctx->n_head = src0->ne[2];
|
||||
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
|
||||
|
||||
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
|
||||
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
|
||||
|
||||
softmax_ctx->use_src1 = (src1->ne[0] != 0);
|
||||
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
softmax_ctx->octx = octx;
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
const int num_elems,
|
||||
float scale,
|
||||
const uint8_t * restrict mask,
|
||||
float slope) {
|
||||
const uint8_t * restrict src_curr = src;
|
||||
uint8_t * restrict dst_curr = dst;
|
||||
const uint8_t * restrict mask_curr = mask;
|
||||
|
||||
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
|
||||
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = *(HVX_Vector *) src_curr;
|
||||
|
||||
HVX_Vector v3 = *(HVX_Vector *) mask_curr;
|
||||
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
|
||||
|
||||
HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec);
|
||||
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5);
|
||||
|
||||
src_curr += VLEN;
|
||||
dst_curr += VLEN;
|
||||
mask_curr += VLEN;
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems) {
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_pad = (HVX_Vector *) pad;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]);
|
||||
HVX_Vector zero_v = Q6_V_vzero();
|
||||
HVX_Vector one_v = hvx_vec_splat_fp32(1.0);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec);
|
||||
max_vec = hvx_vec_repl4(v);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec);
|
||||
|
||||
HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2));
|
||||
|
||||
sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3);
|
||||
|
||||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_qf32_reduce_sum(sum_vec);
|
||||
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v));
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec);
|
||||
HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v);
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = v_pad[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
}
|
||||
}
|
||||
|
||||
static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const int num_elems,
|
||||
const float max) {
|
||||
hvx_sub_scalar_f32(src, max, spad, num_elems);
|
||||
|
||||
hvx_exp_f32(spad, dst, num_elems, false);
|
||||
|
||||
float sum = hvx_self_sum_f32(dst, num_elems);
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const uint32_t i11 = i01;
|
||||
const uint32_t i12 = i02 % ne12;
|
||||
const uint32_t i13 = i03 % ne13;
|
||||
|
||||
// ALiBi
|
||||
const uint32_t h = i02; // head
|
||||
|
||||
const float slope = (softmax_ctx->max_bias > 0.0f) ?
|
||||
h < softmax_ctx->n_head_log2 ?
|
||||
powf(softmax_ctx->m0, h + 1) :
|
||||
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (softmax_ctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
|
||||
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
|
||||
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct softmax_th_ctx softmax_ctx;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
op_func = softmax_job_dispatcher_f32;
|
||||
op_type = "softmax-f32";
|
||||
|
||||
init_softmax_ctx(&softmax_ctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_softmax(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_softmax_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
#pragma clang diagnostic ignored "-Wunused-function"
|
||||
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include <HAP_farf.h>
|
||||
#include <HAP_mem.h>
|
||||
#include <HAP_perf.h>
|
||||
#include <HAP_ps.h>
|
||||
#include <hexagon_protos.h>
|
||||
#include <hexagon_types.h>
|
||||
#include <math.h>
|
||||
#include <qurt_thread.h>
|
||||
#include <string.h>
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-dma.h"
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "ops-utils.h"
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
const uint32_t ne02 = src->ne[2]; \
|
||||
const uint32_t ne03 = src->ne[3]; \
|
||||
\
|
||||
const uint32_t ne0 = dst->ne[0]; \
|
||||
const uint32_t ne1 = dst->ne[1]; \
|
||||
const uint32_t ne2 = dst->ne[2]; \
|
||||
const uint32_t ne3 = dst->ne[3]; \
|
||||
\
|
||||
const uint32_t nb00 = src->nb[0]; \
|
||||
const uint32_t nb01 = src->nb[1]; \
|
||||
const uint32_t nb02 = src->nb[2]; \
|
||||
const uint32_t nb03 = src->nb[3]; \
|
||||
\
|
||||
const uint32_t nb0 = dst->nb[0]; \
|
||||
const uint32_t nb1 = dst->nb[1]; \
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
uint8_t * restrict dst,
|
||||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v);
|
||||
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum));
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v);
|
||||
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
||||
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
||||
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
} else {
|
||||
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
|
||||
|
||||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
int htp_op,
|
||||
int32_t * op_params,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
htp_unary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n");
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
|
||||
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
|
||||
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t unary_op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
const int n_threads = octx->n_threads;
|
||||
const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads;
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||
|
||||
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
||||
int op_unary(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
switch (octx->src0.type) {
|
||||
case HTP_TYPE_F32:
|
||||
err = execute_op_unary_f32(octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
err = HTP_STATUS_NO_SUPPORT;
|
||||
break;
|
||||
}
|
||||
|
||||
return err;
|
||||
}
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
#include "worker-pool.h"
|
||||
|
||||
#include <qurt.h>
|
||||
#include <stdatomic.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef HTP_DEBUG
|
||||
# define FARF_HIGH 1
|
||||
#endif
|
||||
|
||||
#include "HAP_farf.h"
|
||||
|
||||
#define WORKER_THREAD_STACK_SZ (2 * 16384)
|
||||
#define LOWEST_USABLE_QURT_PRIO (254)
|
||||
|
||||
struct worker_pool_s;
|
||||
|
||||
// internal structure kept in thread-local storage per instance of worker pool
|
||||
typedef struct {
|
||||
struct worker_pool_s * pool;
|
||||
unsigned int id;
|
||||
} worker_context_t;
|
||||
|
||||
// internal structure kept in thread-local storage per instance of worker pool
|
||||
typedef struct worker_pool_s {
|
||||
worker_pool_job_t job[MAX_NUM_WORKERS]; // list of job descriptors
|
||||
qurt_thread_t thread[MAX_NUM_WORKERS]; // thread ID's of the workers
|
||||
worker_context_t context[MAX_NUM_WORKERS]; // worker contexts
|
||||
void * stack[MAX_NUM_WORKERS]; // thread stack pointers
|
||||
unsigned int n_threads; // number of workers in this pool
|
||||
|
||||
atomic_uint seqn; // seqno used to detect new jobs
|
||||
atomic_uint next_job; // next job index
|
||||
atomic_uint n_pending; // number of pending jobs
|
||||
atomic_uint n_jobs; // number of current jobs
|
||||
atomic_bool killed; // threads need to exit
|
||||
} worker_pool_t;
|
||||
|
||||
static void worker_pool_main(void * context) {
|
||||
worker_context_t * me = (worker_context_t *) context;
|
||||
worker_pool_t * pool = me->pool;
|
||||
|
||||
FARF(HIGH, "worker-pool: thread %u started", me->id);
|
||||
|
||||
unsigned int prev_seqn = 0;
|
||||
while (!atomic_load(&pool->killed)) {
|
||||
unsigned int seqn = atomic_load(&pool->seqn);
|
||||
if (seqn == prev_seqn) {
|
||||
// Nothing to do
|
||||
qurt_futex_wait(&pool->seqn, prev_seqn);
|
||||
continue;
|
||||
}
|
||||
|
||||
// New job
|
||||
prev_seqn = seqn;
|
||||
|
||||
unsigned int n = atomic_load(&pool->n_jobs);
|
||||
unsigned int i = atomic_fetch_add(&pool->next_job, 1);
|
||||
if (i >= n) {
|
||||
// Spurios wakeup
|
||||
continue;
|
||||
}
|
||||
|
||||
pool->job[i].func(n, i, pool->job[i].data);
|
||||
|
||||
atomic_fetch_sub(&pool->n_pending, 1);
|
||||
}
|
||||
|
||||
FARF(HIGH, "worker-pool: thread %u stopped", me->id);
|
||||
}
|
||||
|
||||
AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) {
|
||||
int err = 0;
|
||||
|
||||
if (NULL == context) {
|
||||
FARF(ERROR, "NULL context passed to worker_pool_init().");
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
// Allocations
|
||||
int size = (stack_size * n_threads) + (sizeof(worker_pool_t));
|
||||
|
||||
unsigned char * mem_blob = (unsigned char *) malloc(size);
|
||||
if (!mem_blob) {
|
||||
FARF(ERROR, "Could not allocate memory for worker pool!!");
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads);
|
||||
|
||||
// name for the first worker, useful in debugging threads
|
||||
char name[19];
|
||||
snprintf(name, 12, "0x%8x:", (int) me);
|
||||
strcat(name, "worker0");
|
||||
me->n_threads = n_threads;
|
||||
|
||||
// initializations
|
||||
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||
me->stack[i] = NULL;
|
||||
me->thread[i] = 0;
|
||||
|
||||
me->context[i].id = i;
|
||||
me->context[i].pool = me;
|
||||
}
|
||||
|
||||
// initialize job queue
|
||||
me->n_pending = 0;
|
||||
me->n_jobs = 0;
|
||||
me->next_job = 0;
|
||||
me->seqn = 0;
|
||||
me->killed = 0;
|
||||
|
||||
// launch the workers
|
||||
qurt_thread_attr_t attr;
|
||||
qurt_thread_attr_init(&attr);
|
||||
|
||||
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||
// set up stack
|
||||
me->stack[i] = mem_blob;
|
||||
mem_blob += stack_size;
|
||||
qurt_thread_attr_set_stack_addr(&attr, me->stack[i]);
|
||||
qurt_thread_attr_set_stack_size(&attr, stack_size);
|
||||
|
||||
// set up name
|
||||
qurt_thread_attr_set_name(&attr, name);
|
||||
name[17] = (name[17] + 1);
|
||||
// name threads context:worker0, context:worker1, .. (recycle at 9, but num threads should be less than that anyway)
|
||||
if (name[17] > '9') {
|
||||
name[17] = '0';
|
||||
}
|
||||
|
||||
// set up priority - by default, match the creating thread's prio
|
||||
int prio = qurt_thread_get_priority(qurt_thread_get_id());
|
||||
|
||||
if (prio < 1) {
|
||||
prio = 1;
|
||||
}
|
||||
if (prio > LOWEST_USABLE_QURT_PRIO) {
|
||||
prio = LOWEST_USABLE_QURT_PRIO;
|
||||
}
|
||||
|
||||
qurt_thread_attr_set_priority(&attr, prio);
|
||||
|
||||
// launch
|
||||
err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]);
|
||||
if (err) {
|
||||
FARF(ERROR, "Could not launch worker threads!");
|
||||
worker_pool_release((worker_pool_context_t *) &me);
|
||||
return AEE_EQURTTHREADCREATE;
|
||||
}
|
||||
}
|
||||
*context = (worker_pool_context_t *) me;
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) {
|
||||
return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ);
|
||||
}
|
||||
|
||||
// clean up worker pool
|
||||
void worker_pool_release(worker_pool_context_t * context) {
|
||||
worker_pool_t * me = (worker_pool_t *) *context;
|
||||
|
||||
// if no worker pool exists, return error.
|
||||
if (NULL == me) {
|
||||
return;
|
||||
}
|
||||
|
||||
atomic_store(&me->killed, 1);
|
||||
atomic_fetch_add(&me->seqn, 1);
|
||||
qurt_futex_wake(&me->seqn, me->n_threads);
|
||||
|
||||
// de-initializations
|
||||
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||
if (me->thread[i]) {
|
||||
int status;
|
||||
(void) qurt_thread_join(me->thread[i], &status);
|
||||
}
|
||||
}
|
||||
|
||||
// free allocated memory (were allocated as a single buffer starting at stack[0])
|
||||
if (me->stack[0]) {
|
||||
free(me->stack[0]);
|
||||
}
|
||||
|
||||
*context = NULL;
|
||||
}
|
||||
|
||||
// run jobs
|
||||
AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) {
|
||||
worker_pool_t * me = (worker_pool_t *) context;
|
||||
if (NULL == me) {
|
||||
FARF(ERROR, "worker-pool: invalid context");
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
if (n > me->n_threads) {
|
||||
FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads);
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
memcpy(me->job, job, sizeof(worker_pool_job_t) * n);
|
||||
|
||||
if (n > 1) {
|
||||
atomic_store(&me->next_job, 1);
|
||||
atomic_store(&me->n_jobs, n);
|
||||
atomic_store(&me->n_pending, n - 1);
|
||||
|
||||
// wake up workers
|
||||
atomic_fetch_add(&me->seqn, 1);
|
||||
qurt_futex_wake(&me->seqn, n - 1);
|
||||
}
|
||||
|
||||
// main thread runs job #0
|
||||
me->job[0].func(n, 0, me->job[0].data);
|
||||
|
||||
if (n > 1) {
|
||||
while (atomic_load(&me->n_pending))
|
||||
;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// run func
|
||||
AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) {
|
||||
worker_pool_job_t job[n];
|
||||
|
||||
for (unsigned int i = 0; i < n; i++) {
|
||||
job[i].func = func;
|
||||
job[i].data = data;
|
||||
}
|
||||
|
||||
return worker_pool_run_jobs(context, job, n);
|
||||
}
|
||||
|
||||
AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) {
|
||||
worker_pool_t * me = (worker_pool_t *) context;
|
||||
|
||||
// if no worker pool exists, return error.
|
||||
if (!me) {
|
||||
return AEE_ENOMORE;
|
||||
}
|
||||
|
||||
int result = AEE_SUCCESS;
|
||||
if (prio < 1) {
|
||||
prio = 1;
|
||||
}
|
||||
if (prio > LOWEST_USABLE_QURT_PRIO) {
|
||||
prio = LOWEST_USABLE_QURT_PRIO;
|
||||
}
|
||||
|
||||
for (unsigned int i = 0; i < me->n_threads; i++) {
|
||||
int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio);
|
||||
if (0 != res) {
|
||||
result = AEE_EBADPARM;
|
||||
FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) {
|
||||
worker_pool_t * me = (worker_pool_t *) context;
|
||||
if (!me) {
|
||||
FARF(ERROR, "worker-pool: invalid context");
|
||||
return AEE_EBADPARM;
|
||||
;
|
||||
}
|
||||
|
||||
for (int i = 0; i < me->n_threads; i++) {
|
||||
tids[i] = me->thread[i];
|
||||
}
|
||||
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) {
|
||||
worker_pool_t * me = (worker_pool_t *) context;
|
||||
if (!me) {
|
||||
FARF(ERROR, "worker-pool: invalid context");
|
||||
return AEE_EBADPARM;
|
||||
}
|
||||
|
||||
int priority = qurt_thread_get_priority(me->thread[0]);
|
||||
if (priority > 0) {
|
||||
*prio = priority;
|
||||
return 0;
|
||||
} else {
|
||||
*prio = 0;
|
||||
return AEE_EBADSTATE;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
#ifndef HTP_WORKER_POOL_H
|
||||
#define HTP_WORKER_POOL_H
|
||||
|
||||
// MACRO enables function to be visible in shared-library case.
|
||||
#define WORKERPOOL_API __attribute__((visibility("default")))
|
||||
|
||||
#include <AEEStdDef.h>
|
||||
#include <AEEStdErr.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// signature of callbacks to be invoked by worker threads
|
||||
typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *);
|
||||
|
||||
/// Typedef of worker_pool context
|
||||
typedef void * worker_pool_context_t;
|
||||
|
||||
/// descriptor for requested callback
|
||||
typedef struct {
|
||||
worker_callback_t func;
|
||||
void * data;
|
||||
} worker_pool_job_t;
|
||||
|
||||
/// Maximum supported number of worker threads.
|
||||
#define MAX_NUM_WORKERS 10
|
||||
|
||||
// Initialize worker pool.
|
||||
WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads);
|
||||
|
||||
// Initialize worker pool with custom stack size
|
||||
WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context,
|
||||
uint32_t n_threads,
|
||||
uint32_t stack_size);
|
||||
|
||||
// Kill worker threads and release worker pool resources
|
||||
WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context);
|
||||
|
||||
// Run jobs with the worker pool.
|
||||
WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n);
|
||||
|
||||
WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context,
|
||||
worker_callback_t func,
|
||||
void * data,
|
||||
unsigned int n);
|
||||
|
||||
WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio);
|
||||
WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio);
|
||||
WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // #ifndef HTP_WORKER_POOL_H
|
||||
|
|
@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
|
|||
" Prefer setting the HIP compiler directly. See README for details.")
|
||||
endif()
|
||||
else()
|
||||
# Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||
if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
# Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
|
||||
if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
|
||||
elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
|
||||
set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
|
||||
endif()
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
|
|
|
|||
|
|
@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
|
|||
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
|
||||
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
|
||||
|
||||
static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
const struct ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
|
||||
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
|
||||
return 0;
|
||||
}
|
||||
return cgraph->use_counts[hash_pos];
|
||||
}
|
||||
|
||||
// return true if the node's results are only used by N other nodes
|
||||
// and can be fused into their calculations.
|
||||
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
|
||||
const struct ggml_tensor * node = cgraph->nodes[node_idx];
|
||||
|
||||
// check the use count against how many we're replacing
|
||||
size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
|
||||
if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
|
||||
if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -638,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
|||
return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
|
||||
}
|
||||
|
||||
GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||
const int * node_idxs,
|
||||
int count,
|
||||
const enum ggml_op * ops,
|
||||
const int * outputs,
|
||||
int num_outputs);
|
||||
|
||||
// Returns true if the subgraph formed by {node_idxs} can be fused
|
||||
// checks whethers all nodes which are not part of outputs can be elided
|
||||
// by checking if their num_uses are confined to the subgraph
|
||||
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
int node_idx,
|
||||
int count,
|
||||
const enum ggml_op * ops,
|
||||
const int * outputs,
|
||||
int num_outputs) {
|
||||
GGML_ASSERT(count < 32);
|
||||
if (node_idx + count > cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int idxs[32];
|
||||
|
||||
for (int i = 0; i < count; ++i) {
|
||||
idxs[i] = node_idx + i;
|
||||
}
|
||||
|
||||
return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
@ -651,6 +690,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
|
|||
return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
|
||||
}
|
||||
|
||||
inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
|
||||
int start_idx,
|
||||
std::initializer_list<enum ggml_op> ops,
|
||||
std::initializer_list<int> outputs = {}) {
|
||||
return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
|
||||
}
|
||||
|
||||
// expose GGUF internals for test code
|
||||
GGML_API size_t gguf_type_size(enum gguf_type type);
|
||||
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
|
||||
|
|
|
|||
|
|
@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
|
|||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous(op->src[1]));
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(op->type == GGML_TYPE_F32);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_UPSCALE);
|
||||
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
|
|||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
|
||||
|
|
|
|||
|
|
@ -653,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_SCALE:
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SQR:
|
||||
|
|
|
|||
|
|
@ -514,6 +514,19 @@ typedef struct {
|
|||
uint64_t nb1;
|
||||
} ggml_metal_kargs_conv_transpose_1d;
|
||||
|
||||
typedef struct {
|
||||
int32_t IC;
|
||||
int32_t IH;
|
||||
int32_t IW;
|
||||
int32_t KH;
|
||||
int32_t KW;
|
||||
int32_t OC;
|
||||
int32_t s0;
|
||||
uint64_t nb0;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
} ggml_metal_kargs_conv_transpose_2d;
|
||||
|
||||
typedef struct {
|
||||
uint64_t ofs0;
|
||||
uint64_t ofs1;
|
||||
|
|
|
|||
|
|
@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|||
{
|
||||
n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
{
|
||||
n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
|
||||
} break;
|
||||
case GGML_OP_UPSCALE:
|
||||
{
|
||||
n_fuse = ggml_metal_op_upscale(ctx, idx);
|
||||
|
|
@ -3118,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
||||
|
||||
const int32_t s0 = ((const int32_t *)(op->op_params))[0];
|
||||
|
||||
const int32_t IC = op->src[1]->ne[2];
|
||||
const int32_t IH = op->src[1]->ne[1];
|
||||
const int32_t IW = op->src[1]->ne[0];
|
||||
|
||||
const int32_t KH = op->src[0]->ne[1];
|
||||
const int32_t KW = op->src[0]->ne[0];
|
||||
|
||||
const int32_t OW = op->ne[0];
|
||||
const int32_t OH = op->ne[1];
|
||||
const int32_t OC = op->ne[2];
|
||||
|
||||
ggml_metal_kargs_conv_transpose_2d args = {
|
||||
/*.IC =*/ IC,
|
||||
/*.IH =*/ IH,
|
||||
/*.IW =*/ IW,
|
||||
/*.KH =*/ KH,
|
||||
/*.KW =*/ KW,
|
||||
/*.OC =*/ OC,
|
||||
/*.s0 =*/ s0,
|
||||
/*.nb0 =*/ nb0,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
};
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
|
||||
|
||||
// Metal requires buffer size to be multiple of 16 bytes
|
||||
const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
|
||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
|
|||
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
|
||||
int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
|
||||
|
|
|
|||
|
|
@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
|
||||
typedef void (conv_transpose_2d_t)(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tgpg[[threadgroups_per_grid]]);
|
||||
|
||||
template <typename T>
|
||||
kernel void kernel_conv_transpose_2d(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const T * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
|
||||
const int64_t out_x = tgpig[0];
|
||||
const int64_t out_y = tgpig[1];
|
||||
const int64_t out_c = tgpig[2];
|
||||
|
||||
const int64_t kw = tpitg[0];
|
||||
const int64_t kh = tpitg[1];
|
||||
|
||||
float v = 0.0f;
|
||||
|
||||
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
|
||||
int64_t in_y = out_y - kh;
|
||||
|
||||
if (in_y < 0 || in_y % args.s0) continue;
|
||||
|
||||
in_y /= args.s0;
|
||||
|
||||
if (in_y >= args.IH) continue;
|
||||
|
||||
int64_t in_x = out_x - kw;
|
||||
|
||||
if (in_x < 0 || in_x % args.s0) continue;
|
||||
|
||||
in_x /= args.s0;
|
||||
|
||||
if (in_x >= args.IW) continue;
|
||||
|
||||
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
|
||||
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
|
||||
|
||||
v += (float)src0[kernel_idx] * src1[input_idx];
|
||||
}
|
||||
|
||||
const uint tid = tpitg.y * ntg.x + tpitg.x;
|
||||
shared_sum[tid] = v;
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tid == 0) {
|
||||
float total = 0.0f;
|
||||
const uint num_threads = ntg.x * ntg.y;
|
||||
for (uint i = 0; i < num_threads; i++) {
|
||||
total += shared_sum[i];
|
||||
}
|
||||
|
||||
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
|
||||
dst_ptr[0] = total;
|
||||
}
|
||||
}
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<float>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
|
||||
kernel void kernel_conv_transpose_2d<half>(
|
||||
constant ggml_metal_kargs_conv_transpose_2d & args,
|
||||
device const half * src0,
|
||||
device const float * src1,
|
||||
device char * dst,
|
||||
threadgroup float * shared_sum [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]);
|
||||
|
||||
kernel void kernel_upscale_f32(
|
||||
constant ggml_metal_kargs_upscale & args,
|
||||
device const char * src0,
|
||||
|
|
|
|||
|
|
@ -91,6 +91,8 @@ set(GGML_OPENCL_KERNELS
|
|||
mul_mv_id_q8_0_f32_flat
|
||||
mul_mv_id_mxfp4_f32
|
||||
mul_mv_id_mxfp4_f32_flat
|
||||
gemm_moe_mxfp4_f32
|
||||
gemv_moe_mxfp4_f32
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q8_0_f32_l4_lm
|
||||
|
|
|
|||
|
|
@ -15,13 +15,12 @@
|
|||
|
||||
#include <CL/cl.h>
|
||||
|
||||
#include <inttypes.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <atomic>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
|
|
@ -402,6 +401,7 @@ struct ggml_backend_opencl_context {
|
|||
cl_program program_conv_2d_f32;
|
||||
cl_program program_conv_2d_f16_f32;
|
||||
cl_program program_tsembd;
|
||||
cl_program program_gemv_moe_mxfp4_f32, program_gemm_moe_mxfp4_f32;
|
||||
cl_program program_mul_mv_id_q4_0_f32_8x_flat;
|
||||
cl_program program_mul_mv_id_q8_0_f32, program_mul_mv_id_q8_0_f32_flat;
|
||||
cl_program program_mul_mv_id_mxfp4_f32;
|
||||
|
|
@ -452,7 +452,7 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_mul_mat_f16_f32_tiled;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
|
||||
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_restore_block_mxfp4;
|
||||
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
|
||||
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0;
|
||||
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_convert_block_q4_0_noshuffle;
|
||||
|
|
@ -475,6 +475,7 @@ struct ggml_backend_opencl_context {
|
|||
cl_kernel kernel_conv_2d_f32;
|
||||
cl_kernel kernel_conv_2d_f16_f32;
|
||||
cl_kernel kernel_timestep_embedding;
|
||||
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
|
||||
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
|
||||
cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat;
|
||||
cl_kernel kernel_mul_mv_id_mxfp4_f32;
|
||||
|
|
@ -531,25 +532,17 @@ struct ggml_backend_opencl_context {
|
|||
}
|
||||
|
||||
// Dump a csv
|
||||
float total_kernel_time = 0;
|
||||
fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
|
||||
fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n");
|
||||
for (const ProfilingInfo & info : profiling_info) {
|
||||
total_kernel_time += info.cmd_duration_ns/1.e6f;
|
||||
fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
|
||||
fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
|
||||
info.op_name.c_str(), info.kernel_name.c_str(),
|
||||
info.cmd_queued_duration_ns/1.e6f,
|
||||
info.cmd_submit_duration_ns/1.e6f,
|
||||
info.cmd_duration_ns/1.e6f,
|
||||
info.cmd_complete_duration_ns/1.e6f,
|
||||
info.cmd_total_duration_ns/1.e6f,
|
||||
info.global_size[0], info.global_size[1], info.global_size[2],
|
||||
info.local_size[0], info.local_size[1], info.local_size[2],
|
||||
info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
|
||||
}
|
||||
fclose(fperf);
|
||||
|
||||
GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
|
||||
|
||||
// Dump a simple chrome trace
|
||||
FILE* ftrace = fopen("cl_trace.json", "w");
|
||||
if (!ftrace) {
|
||||
|
|
@ -559,14 +552,14 @@ struct ggml_backend_opencl_context {
|
|||
|
||||
fprintf(ftrace, "[\n");
|
||||
for (const ProfilingInfo & info : profiling_info) {
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_queued/1000);
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_submit/1000);
|
||||
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_start/1000);
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n",
|
||||
info.kernel_name.c_str(), info.cmd_end/1000);
|
||||
}
|
||||
fclose(ftrace);
|
||||
|
|
@ -777,6 +770,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err));
|
||||
CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err));
|
||||
|
|
@ -1991,6 +1986,42 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
|||
CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std +
|
||||
" -cl-mad-enable "
|
||||
" -cl-fast-relaxed-math";
|
||||
|
||||
// gemv_moe_mxfp4_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemv_moe_mxfp4_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl");
|
||||
#endif
|
||||
backend_ctx->program_gemv_moe_mxfp4_f32 =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
|
||||
// gemm_moe_mxfp4_f32
|
||||
{
|
||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||
const std::string kernel_src {
|
||||
#include "gemm_moe_mxfp4_f32.cl.h"
|
||||
};
|
||||
#else
|
||||
const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl");
|
||||
#endif
|
||||
backend_ctx->program_gemm_moe_mxfp4_f32 =
|
||||
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts);
|
||||
|
||||
CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err));
|
||||
GGML_LOG_CONT(".");
|
||||
}
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
GGML_LOG_CONT("\n");
|
||||
}
|
||||
|
|
@ -3299,6 +3330,12 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c
|
|||
tensor->ne[2] == 1 && tensor->ne[3] == 1;
|
||||
}
|
||||
|
||||
inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) {
|
||||
GGML_UNUSED(backend_ctx);
|
||||
int ne01 = tensor->ne[1];
|
||||
return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0);
|
||||
}
|
||||
|
||||
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
|
||||
|
||||
|
|
@ -3601,14 +3638,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans;
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
int ne02 = tensor->ne[2];
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
|
||||
|
||||
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
|
||||
size_t local_work_size[3] = {64, 2, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4;
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e));
|
||||
|
||||
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[] = {64, 1, 1};
|
||||
size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
|
||||
size_t local_work_size[3] = {64, 1, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
|
||||
|
|
@ -3624,7 +3686,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
|
|||
{ extra->q }
|
||||
};
|
||||
extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err);
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
return;
|
||||
|
|
@ -3751,6 +3812,33 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
|
|||
ggml_nbytes(tensor), NULL, &err);
|
||||
CL_CHECK(err);
|
||||
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_moe_kernels(backend_ctx, tensor)) {
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans;
|
||||
|
||||
int ne00 = tensor->ne[0];
|
||||
int ne01 = tensor->ne[1];
|
||||
int ne02 = tensor->ne[2];
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01));
|
||||
|
||||
size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)};
|
||||
size_t local_work_size[3] = {64, 2, 1};
|
||||
|
||||
cl_event evt;
|
||||
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
|
||||
global_work_size, local_work_size, 0, NULL, &evt));
|
||||
CL_CHECK(clWaitForEvents(1, &evt));
|
||||
CL_CHECK(clEnqueueReadBuffer(
|
||||
queue, data_device, CL_TRUE, offset,
|
||||
size, data, 0, NULL, NULL));
|
||||
CL_CHECK(clReleaseMemObject(data_device));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4;
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e));
|
||||
|
|
@ -7553,6 +7641,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
|||
const int ne21 = src2->ne[1];
|
||||
|
||||
const cl_ulong nb21 = src2->nb[1];
|
||||
const cl_ulong nb20 = src2->nb[0];
|
||||
|
||||
UNUSED(nb20);
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
|
|
@ -7692,6 +7783,105 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0,
|
|||
break;
|
||||
}
|
||||
case GGML_TYPE_MXFP4: {
|
||||
#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
if (use_adreno_moe_kernels(backend_ctx, src0)) {
|
||||
cl_int status;
|
||||
|
||||
size_t local_size[3] = {64, 2, 1};
|
||||
size_t global_size[3] = {64, 2, 1};
|
||||
|
||||
cl_mem src1_sub_buffer, buf_src1_image, buf_src2;
|
||||
|
||||
int tile_size = 320;
|
||||
if (ne12 == 1) { // for gemv
|
||||
kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32;
|
||||
|
||||
// create a sub_buffer for src2
|
||||
cl_buffer_region region;
|
||||
region.origin = offset2;
|
||||
region.size = ne20 * ne21 * sizeof(int);
|
||||
buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// set thread grid
|
||||
global_size[0] = static_cast<size_t>(ne01);
|
||||
global_size[1] = 4;
|
||||
global_size[2] = static_cast<size_t>(ne20);
|
||||
local_size[1] = 4;
|
||||
} else { // for gemm
|
||||
kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32;
|
||||
|
||||
// preprocess router table
|
||||
int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size;
|
||||
void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short));
|
||||
void * host_src2 = malloc(ne21 * nb21);
|
||||
CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL));
|
||||
int total_experts = nb21 / nb20;
|
||||
int out_idx = 0;
|
||||
for (int i_expert = 0; i_expert < ne02; i_expert++) {
|
||||
for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) {
|
||||
for (int j = 0; j < ne21; j++) {
|
||||
for (int i = 0; i < ne20; i++) {
|
||||
int expert = ((int *)host_src2)[j * total_experts + i];
|
||||
if (i_expert == expert) {
|
||||
((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert);
|
||||
((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11));
|
||||
((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i);
|
||||
((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile);
|
||||
out_idx += 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// set thread grid
|
||||
global_size[0] = static_cast<size_t>(tile_size);
|
||||
global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert);
|
||||
}
|
||||
|
||||
// create a sub_buffer for src1
|
||||
cl_buffer_region region;
|
||||
region.origin = offset1;
|
||||
region.size = ne10 * ne11 * ne12 * sizeof(float);
|
||||
src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// create image for src1
|
||||
cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT};
|
||||
cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}};
|
||||
buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status);
|
||||
CL_CHECK(status);
|
||||
|
||||
// Set kernel args
|
||||
int arg_idx = 0;
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01));
|
||||
if (ne12 == 1) {
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11));
|
||||
} else {
|
||||
CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size));
|
||||
}
|
||||
|
||||
// launch kernel
|
||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst);
|
||||
|
||||
// deallocate sub buffers and images
|
||||
CL_CHECK(clReleaseMemObject(src1_sub_buffer));
|
||||
CL_CHECK(clReleaseMemObject(buf_src1_image));
|
||||
CL_CHECK(clReleaseMemObject(buf_src2));
|
||||
return;
|
||||
} // else fallback to generic kernel
|
||||
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
|
||||
|
||||
#ifdef GGML_OPENCL_SOA_Q
|
||||
kernel = backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat;
|
||||
|
||||
|
|
|
|||
|
|
@ -147,6 +147,27 @@ kernel void kernel_convert_block_mxfp4(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_convert_block_mxfp4_trans(
|
||||
global struct block_mxfp4 * src0,
|
||||
__global uint4 * dst_q,
|
||||
__global uchar * dst_e,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
int i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
|
||||
global struct block_mxfp4 * b = src0 + src_blk_offset;
|
||||
|
||||
dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0];
|
||||
dst_e[dst_blk_offset] = b->e;
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_mxfp4(
|
||||
global uchar * src_q,
|
||||
global half * src_e,
|
||||
|
|
@ -162,6 +183,27 @@ kernel void kernel_restore_block_mxfp4(
|
|||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_mxfp4_trans(
|
||||
__global uint4 * src_q,
|
||||
__global uchar * src_e,
|
||||
global struct block_mxfp4 * dst,
|
||||
uint ne00,
|
||||
uint ne01
|
||||
) {
|
||||
int i00 = get_global_id(1);
|
||||
uint i01 = get_global_id(0);
|
||||
uint i02 = get_global_id(2);
|
||||
|
||||
uint ne00_blk = ne00 / QK_MXFP4;
|
||||
uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01;
|
||||
uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01;
|
||||
|
||||
global struct block_mxfp4 * b = dst + dst_blk_offset;
|
||||
|
||||
((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset];
|
||||
b->e = src_e[src_blk_offset];
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q8_0
|
||||
//------------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -0,0 +1,162 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_MXFP4 32
|
||||
#define N_SIMDGROUP 2
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
|
||||
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
|
||||
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
|
||||
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
|
||||
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
|
||||
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
|
||||
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
|
||||
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s0 & 0x8000;
|
||||
|
||||
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
|
||||
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
|
||||
|
||||
ushort2 fp16_packed_a_1, fp16_packed_b_1;
|
||||
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
|
||||
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
|
||||
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
|
||||
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
|
||||
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
|
||||
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
|
||||
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s1 & 0x8000;
|
||||
|
||||
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
|
||||
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
|
||||
|
||||
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uchar x) {
|
||||
int bits;
|
||||
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
|
||||
return as_float(bits);
|
||||
}
|
||||
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemm_moe_mxfp4_f32(
|
||||
__global uint4 * src0_q,
|
||||
__global uchar * src0_e,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global ushort4 * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int tile_size
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
ushort4 router = src2[i20];
|
||||
ushort expert_id = router.x;
|
||||
ushort i11 = router.y;
|
||||
ushort i1 = router.z;
|
||||
ushort tile_id = router.w;
|
||||
|
||||
if (tile_id * tile_size + i01 >= ne01) { // handle edge case when ne01 is not multiple of tile_size
|
||||
return;
|
||||
}
|
||||
|
||||
uint expert_offset = expert_id * ne00 * ne01 / 32;
|
||||
uint tile_offset = expert_offset + tile_id * tile_size + i01;
|
||||
|
||||
__private float sum = 0.0f; // each thread calculate partial sum of one output
|
||||
|
||||
// loop along ne00 in block granularity, skip 4 blocks every iter
|
||||
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
|
||||
// load one block of q
|
||||
uint4 regQ = src0_q[tile_offset + ib00 * ne01];
|
||||
// convert 8 fp4 to fp16
|
||||
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
|
||||
|
||||
uint offset = i11 * ne00 / 4 + ib00 * 8;
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (offset + 0));
|
||||
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 4));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 1));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 5));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 2));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 6));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 3));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 7));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
uchar regE = src0_e[tile_offset + ib00 * ne01];
|
||||
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
// if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
// if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
// if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 outputs per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + tile_id * tile_size + i1 * ne01] = sum;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,156 @@
|
|||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#define QK_MXFP4 32
|
||||
#define N_SIMDGROUP 4
|
||||
#define SIMDGROUP_WIDTH 64
|
||||
|
||||
static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { //, ushort 0x0E00, ushort 0x8000) {
|
||||
ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b;
|
||||
fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00;
|
||||
fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00;
|
||||
fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00;
|
||||
fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0;
|
||||
fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0;
|
||||
fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0;
|
||||
fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s0 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s0 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s0 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s0 & 0x8000;
|
||||
|
||||
fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0;
|
||||
fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0;
|
||||
|
||||
ushort2 fp16_packed_a_1, fp16_packed_b_1;
|
||||
fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00;
|
||||
fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00;
|
||||
fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00;
|
||||
fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00;
|
||||
|
||||
bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0;
|
||||
bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0;
|
||||
bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0;
|
||||
|
||||
fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0;
|
||||
fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0;
|
||||
fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0;
|
||||
fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0;
|
||||
|
||||
sign_a.lo = (fp4x8.s1 << 12) & 0x8000;
|
||||
sign_a.hi = (fp4x8.s1 << 8) & 0x8000;
|
||||
sign_b.lo = (fp4x8.s1 << 4) & 0x8000;
|
||||
sign_b.hi = fp4x8.s1 & 0x8000;
|
||||
|
||||
fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1;
|
||||
fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1;
|
||||
|
||||
return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1));
|
||||
}
|
||||
|
||||
static inline float e8m0_to_fp32(uchar x) {
|
||||
int bits;
|
||||
bits = (x == 0) ? 0x00400000 : ((uint) x << 23);
|
||||
return as_float(bits);
|
||||
}
|
||||
|
||||
|
||||
__attribute__((qcom_reqd_sub_group_size("half")))
|
||||
__kernel void kernel_gemv_moe_mxfp4_f32(
|
||||
__global uint4 * src0_q,
|
||||
__global uchar * src0_e,
|
||||
__read_only image1d_buffer_t src1,
|
||||
__global uint * src2,
|
||||
__global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne11
|
||||
) {
|
||||
uint i01 = get_global_id(0);
|
||||
uint i20 = get_global_id(2);
|
||||
uint sgid = get_local_id(1);
|
||||
uint slid = get_sub_group_local_id();
|
||||
|
||||
uint i11 = i20 % ne11;
|
||||
|
||||
uint expert_id = src2[i20];
|
||||
uint expert_offset = expert_id * ne00 * ne01 / 32;
|
||||
|
||||
__private float sum = 0.0f; // each thread calculate partial sum of one output
|
||||
|
||||
// loop along ne00 in block granularity, skip 4 blocks every iter
|
||||
for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) {
|
||||
|
||||
// load one block of q
|
||||
uint4 regQ = src0_q[expert_offset + ib00 * ne01 + i01];
|
||||
|
||||
uint offset = i11 * ne00 / 4 + ib00 * 8;
|
||||
|
||||
half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0));
|
||||
|
||||
float4 shared_y4;
|
||||
shared_y4 = read_imagef(src1, (offset + 0));
|
||||
float4 acc = shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 4));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 1));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 5));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 2));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 6));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
|
||||
fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3));
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 3));
|
||||
acc += shared_y4 * (float4)(fp16x8.s0, fp16x8.s2, fp16x8.s4, fp16x8.s6);
|
||||
|
||||
shared_y4 = read_imagef(src1, (offset + 7));
|
||||
acc += shared_y4 * (float4)(fp16x8.s1, fp16x8.s3, fp16x8.s5, fp16x8.s7);
|
||||
|
||||
uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset];
|
||||
sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3));
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #subgroups=4
|
||||
__local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)];
|
||||
if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum;
|
||||
if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum;
|
||||
if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
// 1 outputs per thread in subgroup 0
|
||||
if (sgid == 0) {
|
||||
dst = dst + (offsetd >> 2);
|
||||
dst[i01 + i20 * ne01] = sum;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -939,6 +939,7 @@ public:
|
|||
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
|
||||
bool init_tensor(const rpc_msg_init_tensor_req & request);
|
||||
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
|
||||
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
|
||||
|
||||
private:
|
||||
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
|
||||
|
|
@ -1458,6 +1459,20 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
|
|||
return true;
|
||||
}
|
||||
|
||||
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
|
||||
uint32_t dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
return false;
|
||||
}
|
||||
size_t free, total;
|
||||
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
response.free_mem = free;
|
||||
response.total_mem = total;
|
||||
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
|
||||
return true;
|
||||
}
|
||||
|
||||
rpc_server::~rpc_server() {
|
||||
for (auto buffer : buffers) {
|
||||
ggml_backend_buffer_free(buffer);
|
||||
|
|
@ -1465,7 +1480,7 @@ rpc_server::~rpc_server() {
|
|||
}
|
||||
|
||||
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
|
||||
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
|
||||
sockfd_t sockfd) {
|
||||
rpc_server server(backends, cache_dir);
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
|
|
@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
return;
|
||||
}
|
||||
auto dev_id = request.device;
|
||||
if (dev_id >= backends.size()) {
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
if (!server.get_device_memory(request, response)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_get_device_memory_rsp response;
|
||||
response.free_mem = free_mem[dev_id];
|
||||
response.total_mem = total_mem[dev_id];
|
||||
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
|
||||
response.free_mem, response.total_mem);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
|
|||
}
|
||||
|
||||
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
|
||||
size_t n_threads, size_t n_devices,
|
||||
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
|
||||
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
|
||||
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
|
||||
if (n_devices == 0 || devices == nullptr) {
|
||||
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
|
||||
return;
|
||||
}
|
||||
std::vector<ggml_backend_t> backends;
|
||||
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
|
||||
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
|
||||
printf("Starting RPC server v%d.%d.%d\n",
|
||||
RPC_PROTO_MAJOR_VERSION,
|
||||
RPC_PROTO_MINOR_VERSION,
|
||||
|
|
@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||
printf("Devices:\n");
|
||||
for (size_t i = 0; i < n_devices; i++) {
|
||||
auto dev = devices[i];
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
|
||||
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
|
||||
total / 1024 / 1024, free / 1024 / 1024);
|
||||
auto backend = ggml_backend_dev_init(dev, nullptr);
|
||||
if (!backend) {
|
||||
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
|
||||
|
|
@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
|
|||
}
|
||||
printf("Accepted client connection\n");
|
||||
fflush(stdout);
|
||||
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
|
||||
rpc_serve_client(backends, cache_dir, client_socket->fd);
|
||||
printf("Client connection closed\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,5 +37,7 @@
|
|||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "wkv.hpp"
|
||||
#include "pad_reflect_1d.hpp"
|
||||
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
|||
|
|
@ -150,6 +150,26 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
|
|||
return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_floor(T x) {
|
||||
return sycl::floor(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_ceil(T x) {
|
||||
return sycl::ceil(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_round(T x) {
|
||||
return sycl::round(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static __dpct_inline__ T op_trunc(T x) {
|
||||
return sycl::trunc(x);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
|
|
@ -304,6 +324,34 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_floor(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_ceil(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_round(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = op_trunc(x[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
|
|
@ -397,6 +445,14 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst,
|
|||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void arange_kernel(T * dst, const int k, T start, T step,
|
||||
const sycl::nd_item<1> &item_ct1) {
|
||||
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
|
||||
dst[i] = start + static_cast<T>(i) * step;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
|
|
@ -565,6 +621,25 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx
|
|||
}
|
||||
|
||||
|
||||
static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
float start, stop, step;
|
||||
memcpy(&start, dst->op_params, sizeof(float));
|
||||
memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
|
||||
memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
float * dst_ptr = (float *)dst->data;
|
||||
const int k = (int)ggml_nelements(dst);
|
||||
const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
|
||||
sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
arange_kernel(dst_ptr, k, start, step, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ggml_sycl_detail
|
||||
|
||||
|
||||
|
|
@ -870,6 +945,58 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens
|
|||
}, min_val, max_val);
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
|
||||
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
|
||||
const int num_blocks = ceil_div(k_elements, 256);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
|
||||
sycl::range<1>(256)),
|
||||
[=](sycl::nd_item<1> item_ct1) {
|
||||
unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
|
||||
|
|
@ -1090,3 +1217,28 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_geglu_quick(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
|
||||
ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_floor(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_ceil(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_round(ctx, dst);
|
||||
}
|
||||
|
||||
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
ggml_sycl_op_trunc(ctx, dst);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -80,5 +80,11 @@ void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
|||
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
||||
|
|
|
|||
|
|
@ -30,6 +30,9 @@
|
|||
#include <regex>
|
||||
|
||||
#include <sycl/sycl.hpp>
|
||||
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
||||
# include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp>
|
||||
#endif
|
||||
#include <sycl/half_type.hpp>
|
||||
|
||||
#include "ggml-sycl.h"
|
||||
|
|
@ -42,6 +45,7 @@
|
|||
#include "ggml-sycl/presets.hpp"
|
||||
#include "ggml-sycl/gemm.hpp"
|
||||
#include "ggml-sycl/set_rows.hpp"
|
||||
#include "ggml-sycl/set.hpp"
|
||||
#include "ggml-sycl/sycl_hw.hpp"
|
||||
#include "ggml-sycl/getrows.hpp"
|
||||
#include "ggml-sycl/quantize.hpp"
|
||||
|
|
@ -53,6 +57,7 @@ int g_ggml_sycl_disable_optimize = 0;
|
|||
int g_ggml_sycl_disable_graph = 0;
|
||||
int g_ggml_sycl_disable_dnn = 0;
|
||||
int g_ggml_sycl_prioritize_dmmv = 0;
|
||||
int g_ggml_sycl_use_async_mem_op = 0;
|
||||
|
||||
static ggml_sycl_device_info ggml_sycl_init() {
|
||||
ggml_sycl_device_info info = {};
|
||||
|
|
@ -236,7 +241,20 @@ static void ggml_check_sycl() try {
|
|||
fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
|
||||
#endif
|
||||
*/
|
||||
|
||||
// Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be
|
||||
// properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in
|
||||
// other places.
|
||||
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
||||
g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph;
|
||||
if (g_ggml_sycl_use_async_mem_op) {
|
||||
for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) {
|
||||
if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) {
|
||||
g_ggml_sycl_use_async_mem_op = 0;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
if (CHECK_TRY_ERROR(g_all_sycl_device_count =
|
||||
dpct::dev_mgr::instance().device_count()) != 0) {
|
||||
initialized = true;
|
||||
|
|
@ -2151,6 +2169,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
|
|||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
dpct::queue_ptr main_stream = ctx.stream();
|
||||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
|
||||
const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
|
||||
float * dst_dd = static_cast<float *>(dst->data);
|
||||
|
||||
const int64_t ncols = dst->src[0]->ne[0];
|
||||
const int64_t nrows = ggml_nrows(dst->src[0]);
|
||||
|
||||
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
|
||||
|
||||
main_stream->parallel_for(
|
||||
sycl::range<1>(nrows),
|
||||
[=](sycl::id<1> row) {
|
||||
dst_dd[row] /= ncols;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_I32);
|
||||
|
|
@ -3006,19 +3048,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) {
|
|||
}
|
||||
}
|
||||
|
||||
// Helper functions to unify device memory allocation for both async and sync paths
|
||||
static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) {
|
||||
bool use_async = g_ggml_sycl_use_async_mem_op;
|
||||
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
||||
if (use_async) {
|
||||
return syclex::async_malloc(*stream, sycl::usm::alloc::device, size);
|
||||
}
|
||||
#else
|
||||
// If async allocation extension is not available, use_async should always be false.
|
||||
GGML_ASSERT(!use_async);
|
||||
#endif
|
||||
return sycl::malloc(size, *stream, sycl::usm::alloc::device);
|
||||
}
|
||||
|
||||
static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) {
|
||||
bool use_async = g_ggml_sycl_use_async_mem_op;
|
||||
#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC
|
||||
if (use_async) {
|
||||
syclex::async_free(*stream, ptr);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
// If async allocation extension is not available, use_async should always be false.
|
||||
GGML_ASSERT(!use_async);
|
||||
#endif
|
||||
sycl::free(ptr, *stream);
|
||||
}
|
||||
|
||||
static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset,
|
||||
dpct::queue_ptr stream) {
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(
|
||||
CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size)
|
||||
.wait()));
|
||||
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
||||
|
||||
sycl::event copy_event;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
copy_event.wait();
|
||||
}
|
||||
|
||||
GGML_ASSERT((size % sizeof(block_q4_0) == 0));
|
||||
GGML_ASSERT((offset % sizeof(block_q4_0) == 0));
|
||||
int offset_blks = offset / sizeof(block_q4_0);
|
||||
auto qs_ptr = data_device + offset_blks * QK4_0 / 2;
|
||||
auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks;
|
||||
|
||||
stream->parallel_for(
|
||||
auto reorder_event = stream->parallel_for(
|
||||
size / sizeof(block_q4_0),
|
||||
[=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
const block_q4_0* x = (const block_q4_0*)tmp_buf;
|
||||
|
|
@ -3029,9 +3103,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr
|
|||
*(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j];
|
||||
}
|
||||
*(d_ptr + ib) = x[ib].d;
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
});
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
reorder_event.wait_and_throw();
|
||||
}
|
||||
sycl_ext_free(stream, tmp_buf);
|
||||
}
|
||||
|
||||
static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
|
|
@ -3040,14 +3116,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|||
|
||||
const int nblocks = size / sizeof(block_q4_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
||||
|
||||
sycl::event copy_event;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
copy_event.wait();
|
||||
}
|
||||
|
||||
auto * qs_ptr = data_device;
|
||||
auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks;
|
||||
auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks);
|
||||
|
||||
stream->parallel_for(nblocks, [=](auto i) {
|
||||
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
||||
const block_q4_K * x = (const block_q4_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
|
|
@ -3060,9 +3141,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
|
|||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].dm;
|
||||
}).wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
});
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
reorder_event.wait_and_throw();
|
||||
}
|
||||
sycl_ext_free(stream, tmp_buf);
|
||||
}
|
||||
|
||||
static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
|
||||
|
|
@ -3071,17 +3154,20 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
|
|||
|
||||
const int nblocks = size / sizeof(block_q6_K);
|
||||
|
||||
auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
|
||||
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));
|
||||
uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size));
|
||||
|
||||
sycl::event copy_event;
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size)));
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
copy_event.wait();
|
||||
}
|
||||
|
||||
auto * ql_ptr = data_device;
|
||||
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
|
||||
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
|
||||
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);
|
||||
|
||||
stream
|
||||
->parallel_for(nblocks,
|
||||
[=](auto i) {
|
||||
auto reorder_event = stream->parallel_for(nblocks, [=](auto i) {
|
||||
const block_q6_K * x = (const block_q6_K *) tmp_buf;
|
||||
const int ib = i;
|
||||
|
||||
|
|
@ -3103,10 +3189,11 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d
|
|||
}
|
||||
|
||||
dm_ptr[ib] = x[ib].d;
|
||||
})
|
||||
.wait_and_throw();
|
||||
|
||||
sycl::free(tmp_buf, *stream);
|
||||
});
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
reorder_event.wait_and_throw();
|
||||
}
|
||||
sycl_ext_free(stream, tmp_buf);
|
||||
}
|
||||
|
||||
static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
|
||||
|
|
@ -3535,6 +3622,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
|||
ggml_sycl_op_sum_rows(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
ggml_sycl_op_mean(ctx, dst);
|
||||
}
|
||||
|
||||
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
||||
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
|
||||
|
|
@ -3589,6 +3682,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_GET_ROWS:
|
||||
ggml_sycl_get_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET:
|
||||
ggml_sycl_op_set(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SET_ROWS:
|
||||
ggml_sycl_op_set_rows(ctx, dst);
|
||||
break;
|
||||
|
|
@ -3664,6 +3760,18 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_UNARY_OP_ELU:
|
||||
ggml_sycl_elu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
ggml_sycl_floor(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
ggml_sycl_ceil(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
ggml_sycl_round(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
ggml_sycl_trunc(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -3698,6 +3806,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_CONCAT:
|
||||
ggml_sycl_op_concat(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
ggml_sycl_op_pad_reflect_1d(ctx,dst);
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
ggml_sycl_upscale(ctx, dst);
|
||||
break;
|
||||
|
|
@ -3784,6 +3895,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_SUM_ROWS:
|
||||
ggml_sycl_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_sycl_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
ggml_sycl_argsort(ctx, dst);
|
||||
break;
|
||||
|
|
@ -3799,6 +3913,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
|||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_sycl_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ARANGE:
|
||||
ggml_sycl_arange(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -4001,6 +4118,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) {
|
|||
GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
|
||||
ggml_op_name(node_op));
|
||||
return false;
|
||||
case GGML_OP_MUL_MAT:
|
||||
// We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available,
|
||||
// as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present
|
||||
// in reordering.
|
||||
if (!g_ggml_sycl_use_async_mem_op) {
|
||||
GGML_LOG_INFO(
|
||||
"%s: disabling SYCL graphs due to unsupported node type when using a compiler without the "
|
||||
"oneAPI async memory allocation extension "
|
||||
"%s\n",
|
||||
__func__, ggml_op_name(node_op));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
@ -4222,6 +4351,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_UNARY_OP_SGN:
|
||||
case GGML_UNARY_OP_ABS:
|
||||
case GGML_UNARY_OP_ELU:
|
||||
case GGML_UNARY_OP_FLOOR:
|
||||
case GGML_UNARY_OP_CEIL:
|
||||
case GGML_UNARY_OP_ROUND:
|
||||
case GGML_UNARY_OP_TRUNC:
|
||||
#if defined (GGML_SYCL_F16)
|
||||
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
|
||||
#else
|
||||
|
|
@ -4295,6 +4428,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_SET:
|
||||
return (op->type == GGML_TYPE_F32) &&
|
||||
(op->src[0] && op->src[1]) &&
|
||||
(op->src[0]->type == GGML_TYPE_F32) &&
|
||||
(op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
|
||||
|
|
@ -4393,6 +4532,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_DIV:
|
||||
case GGML_OP_REPEAT:
|
||||
return true;
|
||||
case GGML_OP_PAD_REFLECT_1D:
|
||||
return ggml_is_contiguous(op->src[0]) && op-> type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SQR:
|
||||
case GGML_OP_SQRT:
|
||||
case GGML_OP_SIN:
|
||||
|
|
@ -4431,6 +4572,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_POOL_2D:
|
||||
|
|
@ -4444,6 +4586,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_ARANGE:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,72 @@
|
|||
#include "pad_reflect_1d.hpp"
|
||||
|
||||
void pad_reflect_1d_f32(const float* src,float* dst,
|
||||
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
|
||||
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
|
||||
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
|
||||
const sycl::nd_item<3> &item_ct1){
|
||||
|
||||
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
|
||||
const int i1 = item_ct1.get_group(1);
|
||||
const int g2 = item_ct1.get_group(2);
|
||||
const int i2 = g2 % ne02;
|
||||
const int i3 = g2 / ne02;
|
||||
|
||||
if (i0 >= p0 + ne0 + p1) return;
|
||||
|
||||
int t = i0 - p0;
|
||||
int period = 2 * ne0 -2;
|
||||
int m = t % period;
|
||||
m += (m < 0) * period;
|
||||
int center = ne0 -1;
|
||||
int srci0 = center - abs(center - m);
|
||||
|
||||
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
|
||||
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
|
||||
dst[offest_dst] = src[offest_src];
|
||||
|
||||
}
|
||||
|
||||
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t * opts = (const int32_t *) dst->op_params;
|
||||
const int p0 = opts[0];
|
||||
const int p1 = opts[1];
|
||||
|
||||
const int64_t ne0 = src0->ne[0];
|
||||
|
||||
const int64_t ne00 = dst->ne[0];
|
||||
const int64_t ne01 = dst->ne[1];
|
||||
const int64_t ne02 = dst->ne[2];
|
||||
const int64_t ne03 = dst->ne[3];
|
||||
|
||||
const int64_t nb00 = dst->nb[0];
|
||||
const int64_t nb01 = dst->nb[1];
|
||||
const int64_t nb02 = dst->nb[2];
|
||||
const int64_t nb03 = dst->nb[3];
|
||||
const int64_t nb0 = src0->nb[0];
|
||||
const int64_t nb1 = src0->nb[1];
|
||||
const int64_t nb2 = src0->nb[2];
|
||||
const int64_t nb3 = src0->nb[3];
|
||||
|
||||
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
|
||||
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
|
||||
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(global,
|
||||
local),
|
||||
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
|
||||
(const float *) src0->data, (float *) dst->data,
|
||||
ne0, ne02, p0, p1,
|
||||
nb0, nb1, nb2, nb3,
|
||||
nb00, nb01, nb02, nb03
|
||||
, item_ct1);
|
||||
});
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
#ifndef GGML_SYCL_PAD_REFLECT_1D_HPP
|
||||
#define GGML_SYCL_PAD_REFLECT_1D_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
|
||||
|
||||
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP
|
||||
|
|
@ -31,6 +31,7 @@
|
|||
#define SYCL_SQRT_BLOCK_SIZE 256
|
||||
#define SYCL_SIN_BLOCK_SIZE 256
|
||||
#define SYCL_SQR_BLOCK_SIZE 256
|
||||
#define SYCL_SET_BLOCK_SIZE 256
|
||||
#define SYCL_CPY_BLOCK_SIZE 32
|
||||
#define SYCL_SCALE_BLOCK_SIZE 256
|
||||
#define SYCL_CLAMP_BLOCK_SIZE 256
|
||||
|
|
@ -49,6 +50,7 @@
|
|||
#define SYCL_ARGMAX_BLOCK_SIZE 256
|
||||
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
|
||||
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
||||
#define SYCL_ARANGE_BLOCK_SIZE 256
|
||||
|
||||
// dmmv = dequantize_mul_mat_vec
|
||||
#ifndef GGML_SYCL_DMMV_X
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
#include "presets.hpp"
|
||||
#include "common.hpp"
|
||||
#include "ggml.h"
|
||||
#include "set.hpp"
|
||||
#include <cstdint>
|
||||
#include <sycl/sycl.hpp>
|
||||
using namespace sycl;
|
||||
|
||||
// Internal function: perform element-wise set operation for each thread
|
||||
inline void set_f32(const float* src, float* dst,
|
||||
const int64_t ne0, const int64_t ne1,
|
||||
const int64_t ne2, const int64_t ne3,
|
||||
const int64_t nb[3], const int64_t src_nb[3],
|
||||
const int64_t offset_elem,
|
||||
const nd_item<1>& item)
|
||||
{
|
||||
const size_t idx = item.get_global_id(0);
|
||||
const size_t total = ne0 * ne1 * ne2 * ne3;
|
||||
if (idx >= total) return;
|
||||
|
||||
// Convert linear index to 4D indices
|
||||
const size_t i3 = idx / (ne2 * ne1 * ne0);
|
||||
const size_t rem = idx % (ne2 * ne1 * ne0);
|
||||
const size_t i2 = rem / (ne1 * ne0);
|
||||
const size_t rem2 = rem % (ne1 * ne0);
|
||||
const size_t i1 = rem2 / ne0;
|
||||
const size_t i0 = rem2 % ne0;
|
||||
|
||||
// Compute source and destination indices and copy
|
||||
dst[i0 + i1*nb[0] + i2*nb[1] + i3*nb[2] + offset_elem] =
|
||||
src[i0 + i1*src_nb[0] + i2*src_nb[1] + i3*src_nb[2]];
|
||||
}
|
||||
|
||||
// Main function: prepare GPU queue and launch parallel_for
|
||||
void ggml_sycl_op_set(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
|
||||
const ggml_tensor* src0 = dst->src[0];
|
||||
const ggml_tensor* src1 = dst->src[1];
|
||||
|
||||
// Ensure shapes and types are compatible
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(dst->type == src0->type && src0->type == src1->type && dst->type == GGML_TYPE_F32);
|
||||
|
||||
const int32_t* opts = (const int32_t*) dst->op_params;
|
||||
const int64_t nb[3] = {opts[0]/sizeof(float), opts[1]/sizeof(float), opts[2]/sizeof(float)};
|
||||
const int64_t offset_elem = opts[3] / sizeof(float);
|
||||
const bool inplace = opts[4];
|
||||
|
||||
float* dst_ptr = (float*) dst->data;
|
||||
const float* src0_ptr = (const float*) src0->data;
|
||||
const float* src1_ptr = (const float*) src1->data;
|
||||
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
// Copy src0 to dst if not inplace
|
||||
if (!inplace)
|
||||
stream->memcpy(dst_ptr, src0_ptr, ggml_nbytes(dst));
|
||||
|
||||
const int64_t ne[4] = {src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]};
|
||||
const int64_t src_nb[3] = {src1->nb[1]/sizeof(float), src1->nb[2]/sizeof(float), src1->nb[3]/sizeof(float)};
|
||||
|
||||
const size_t total_threads = ne[0]*ne[1]*ne[2]*ne[3];
|
||||
const size_t grid_size = ((total_threads + SYCL_SET_BLOCK_SIZE - 1) / SYCL_SET_BLOCK_SIZE) * SYCL_SET_BLOCK_SIZE;
|
||||
|
||||
// Copy src0 to dst if not inplace
|
||||
stream->parallel_for(
|
||||
nd_range<1>(range<1>(grid_size), range<1>(SYCL_SET_BLOCK_SIZE)),
|
||||
[=](nd_item<1> item) {
|
||||
set_f32(src1_ptr, dst_ptr,
|
||||
ne[0], ne[1], ne[2], ne[3],
|
||||
nb, src_nb, offset_elem, item); }
|
||||
);
|
||||
}
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#pragma once
|
||||
#include "backend.hpp"
|
||||
#include "ggml.h"
|
||||
|
||||
void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||
|
|
@ -385,6 +385,14 @@ enum shader_reduction_mode {
|
|||
|
||||
static constexpr uint32_t num_argsort_pipelines = 11;
|
||||
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
||||
static constexpr uint32_t num_topk_moe_pipelines = 10;
|
||||
|
||||
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
||||
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
||||
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
||||
GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
||||
|
||||
|
||||
struct vk_device_struct {
|
||||
std::recursive_mutex mutex;
|
||||
|
|
@ -582,6 +590,9 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_pool2d_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_ssm_scan_f32_d128;
|
||||
vk_pipeline pipeline_ssm_scan_f32_d256;
|
||||
vk_pipeline pipeline_ssm_conv_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
|
|
@ -595,6 +606,9 @@ struct vk_device_struct {
|
|||
|
||||
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
||||
|
||||
// [2] is {!norm, norm}
|
||||
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
|
||||
|
||||
std::vector<vk_pipeline_ref> all_pipelines;
|
||||
|
||||
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
|
||||
|
|
@ -938,6 +952,11 @@ struct vk_op_multi_add_push_constants {
|
|||
static_assert(MAX_PARAMETER_COUNT == 12);
|
||||
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
||||
|
||||
struct vk_op_topk_moe_push_constants {
|
||||
uint32_t n_rows;
|
||||
uint32_t n_expert_used;
|
||||
};
|
||||
|
||||
struct vk_op_add_id_push_constants {
|
||||
uint32_t ne0;
|
||||
uint32_t ne1;
|
||||
|
|
@ -1087,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
|
|||
uint32_t C;
|
||||
uint32_t H;
|
||||
};
|
||||
struct vk_op_ssm_scan_push_constants {
|
||||
uint32_t nb02, nb03, nb12, nb13;
|
||||
uint32_t nb21, nb22, nb31;
|
||||
uint32_t nb42, nb43, nb52, nb53;
|
||||
uint32_t s_off;
|
||||
uint32_t n_head, d_head, n_group, n_tok;
|
||||
};
|
||||
struct vk_op_ssm_conv_push_constants {
|
||||
uint32_t nb01, nb02;
|
||||
uint32_t nb11;
|
||||
uint32_t dst_nb0, dst_nb1, dst_nb2;
|
||||
uint32_t nc, ncs, nr, n_t, n_s;
|
||||
};
|
||||
|
||||
struct vk_op_conv2d_push_constants {
|
||||
uint32_t Cout;
|
||||
|
|
@ -3591,6 +3623,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
|
@ -3701,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
|
||||
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
|
||||
}
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
@ -7983,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
|
||||
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
||||
GGML_ASSERT(idx < num_topk_moe_pipelines);
|
||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
||||
return ctx->device->pipeline_topk_moe[idx][with_norm];
|
||||
}
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
||||
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
||||
}
|
||||
|
|
@ -8098,6 +8147,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_rwkv_wkv7_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
const uint32_t d_state = src0->ne[0];
|
||||
if (d_state == 128) {
|
||||
return ctx->device->pipeline_ssm_scan_f32_d128;
|
||||
} else if (d_state == 256) {
|
||||
return ctx->device->pipeline_ssm_scan_f32_d256;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SSM_CONV:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_ssm_conv_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
|
|
@ -8592,6 +8656,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
}
|
||||
}
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
{
|
||||
const uint32_t nr = src0->ne[1];
|
||||
const uint32_t n_t = dst->ne[1];
|
||||
const uint32_t n_s = dst->ne[2];
|
||||
elements = { nr, n_t, n_s };
|
||||
}
|
||||
break;
|
||||
default:
|
||||
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
|
||||
break;
|
||||
|
|
@ -9038,6 +9110,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
const ggml_tensor * src3 = dst->src[3];
|
||||
const ggml_tensor * src4 = dst->src[4];
|
||||
const ggml_tensor * src5 = dst->src[5];
|
||||
|
||||
GGML_ASSERT(dst->buffer != nullptr);
|
||||
|
||||
const uint32_t head_dim = src0->ne[1];
|
||||
const uint32_t n_head = src1->ne[1];
|
||||
const uint32_t n_group = src4->ne[1];
|
||||
const uint32_t n_tok = src1->ne[2];
|
||||
const uint32_t n_seq = src1->ne[3];
|
||||
|
||||
bool is_mamba2 = (src3->nb[1] == sizeof(float));
|
||||
GGML_ASSERT(is_mamba2);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
|
||||
GGML_ASSERT(pipeline != nullptr);
|
||||
|
||||
if (dryrun) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
|
||||
|
||||
const vk_op_ssm_scan_push_constants pc = {
|
||||
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
|
||||
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
|
||||
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
|
||||
(uint32_t)src3->nb[1],
|
||||
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
|
||||
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
|
||||
(uint32_t)s_off,
|
||||
n_head, head_dim, n_group, n_tok
|
||||
};
|
||||
|
||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
|
||||
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
|
||||
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
|
||||
}
|
||||
|
||||
vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
|
||||
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
|
||||
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
|
||||
|
||||
if (ctx->device->uma) {
|
||||
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
|
||||
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
|
||||
srcs_uma[i] = d_srcs[i] != nullptr;
|
||||
}
|
||||
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
|
||||
dst_uma = d_D != nullptr;
|
||||
}
|
||||
|
||||
if (!dst_uma) {
|
||||
d_D = dst_buf_ctx->dev_buffer;
|
||||
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||
}
|
||||
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
|
||||
if (!srcs_uma[i]) {
|
||||
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
|
||||
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
|
||||
}
|
||||
}
|
||||
|
||||
size_t dst_size = ggml_nbytes(dst);
|
||||
size_t src_sizes[GGML_MAX_SRC];
|
||||
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
|
||||
src_sizes[i] = ggml_nbytes(dst->src[i]);
|
||||
}
|
||||
|
||||
std::array<uint32_t, 3> elements;
|
||||
|
||||
const int splitH = 16;
|
||||
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
|
||||
const uint32_t num_workgroups_y = n_seq;
|
||||
elements = { num_workgroups_x, num_workgroups_y, 1 };
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
|
||||
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
|
||||
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
|
||||
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
|
||||
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
|
||||
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
|
||||
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
|
||||
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
|
||||
vk_subbuffer{ d_D, dst_offset, dst_size }
|
||||
}, pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
|
||||
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
|
||||
(uint32_t)src1->nb[1],
|
||||
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
|
||||
(uint32_t)src1->ne[0],
|
||||
(uint32_t)src0->ne[0],
|
||||
(uint32_t)src0->ne[1],
|
||||
(uint32_t)dst->ne[1],
|
||||
(uint32_t)dst->ne[2],
|
||||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
|
||||
const ggml_tensor * x = dst->src[0];
|
||||
const ggml_tensor * g = dst->src[1];
|
||||
|
|
@ -9434,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
||||
|
||||
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
||||
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
||||
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
||||
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
|
||||
|
||||
GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
|
||||
const int n_experts = logits->ne[0];
|
||||
const int n_rows = logits->ne[1];
|
||||
const int n_expert_used = weights->ne[1];
|
||||
|
||||
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
|
||||
|
||||
if (dryrun) {
|
||||
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
|
||||
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
|
||||
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
|
||||
|
||||
vk_buffer d_logits = nullptr;
|
||||
size_t logits_buf_offset = 0;
|
||||
vk_buffer d_weights = nullptr;
|
||||
size_t weights_buf_offset = 0;
|
||||
vk_buffer d_ids = nullptr;
|
||||
size_t ids_buf_offset = 0;
|
||||
|
||||
bool logits_uma = false;
|
||||
bool weights_uma = false;
|
||||
bool ids_uma = false;
|
||||
|
||||
if (ctx->device->uma) {
|
||||
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
|
||||
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
|
||||
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
|
||||
logits_uma = d_logits != nullptr;
|
||||
weights_uma = d_weights != nullptr;
|
||||
ids_uma = d_ids != nullptr;
|
||||
}
|
||||
|
||||
if (!logits_uma) {
|
||||
d_logits = logits_buf_ctx->dev_buffer;
|
||||
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
|
||||
GGML_ASSERT(d_logits != nullptr);
|
||||
}
|
||||
if (!weights_uma) {
|
||||
d_weights = weights_buf_ctx->dev_buffer;
|
||||
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
|
||||
GGML_ASSERT(d_weights != nullptr);
|
||||
}
|
||||
if (!ids_uma) {
|
||||
d_ids = ids_buf_ctx->dev_buffer;
|
||||
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
|
||||
GGML_ASSERT(d_ids != nullptr);
|
||||
}
|
||||
|
||||
vk_op_topk_moe_push_constants pc;
|
||||
pc.n_rows = n_rows;
|
||||
pc.n_expert_used = n_expert_used;
|
||||
|
||||
GGML_ASSERT(n_expert_used <= n_experts);
|
||||
|
||||
const uint32_t rows_per_block = 4;
|
||||
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
|
||||
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{
|
||||
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
|
||||
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
|
||||
}, pc, elements);
|
||||
}
|
||||
|
||||
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
|
||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
||||
const int mode = ((int32_t *) dst->op_params)[2];
|
||||
|
|
@ -10870,6 +11134,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
|
|
@ -11017,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
ctx->unsynced_nodes_read.clear();
|
||||
ggml_vk_sync_buffers(ctx, compute_ctx);
|
||||
}
|
||||
// Add the last fused node and all fused source nodes to the unsynchronized list.
|
||||
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
|
||||
ctx->unsynced_nodes_written.push_back(last_node);
|
||||
// Add all fused nodes to the unsynchronized lists.
|
||||
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
||||
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
|
||||
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
|
||||
ctx->unsynced_nodes_written.push_back(cur_node);
|
||||
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
|
||||
if (!cur_node->src[j]) {
|
||||
continue;
|
||||
|
|
@ -11188,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
|
||||
} else {
|
||||
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
||||
}
|
||||
|
||||
break;
|
||||
case GGML_OP_SOFT_MAX_BACK:
|
||||
|
|
@ -11287,6 +11557,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|||
|
||||
break;
|
||||
|
||||
case GGML_OP_SSM_SCAN:
|
||||
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
|
|
@ -11398,6 +11678,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
|
|
@ -11972,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
||||
int node_idx, bool with_norm) {
|
||||
|
||||
if (with_norm) {
|
||||
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < topk_moe.size(); ++i) {
|
||||
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
|
||||
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
||||
|
||||
const float * op_params = (const float *)softmax->op_params;
|
||||
|
||||
float scale = op_params[0];
|
||||
float max_bias = op_params[1];
|
||||
|
||||
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (scale != 1.0f || max_bias != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// don't fuse when masks or sinks are present
|
||||
if (softmax->src[1] || softmax->src[2]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_expert = softmax->ne[0];
|
||||
// n_expert must be a power of 2
|
||||
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that the nodes don't have any unexpected uses
|
||||
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
|
||||
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
||||
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
|
||||
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
|
||||
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
|
||||
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
|
||||
|
||||
// softmax is used by reshape and argsort
|
||||
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
|
||||
reshape1->src[0] != softmax ||
|
||||
argsort->src[0] != softmax) {
|
||||
return false;
|
||||
}
|
||||
// reshape is used by get_rows
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
|
||||
get_rows->src[0] != reshape1) {
|
||||
return false;
|
||||
}
|
||||
// argsort is used by view
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
|
||||
view->src[0] != argsort) {
|
||||
return false;
|
||||
}
|
||||
// view is written (via argsort), we can skip checking it
|
||||
|
||||
if (with_norm) {
|
||||
// get_rows is used by reshape
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
|
||||
reshape5->src[0] != get_rows) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// reshape is used by sum_rows and div
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
|
||||
sum_rows->src[0] != reshape5 ||
|
||||
div->src[0] != reshape5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// sum_rows is used by div
|
||||
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
|
||||
div->src[1] != sum_rows) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// div/reshape are written
|
||||
if (reshape8->src[0] != div) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ctx->device->subgroup_arithmetic ||
|
||||
!ctx->device->subgroup_shuffle ||
|
||||
!ctx->device->subgroup_require_full_support ||
|
||||
ctx->device->disable_fusion) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
||||
|
||||
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
||||
|
|
@ -12047,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
ctx->num_additional_fused_ops = num_adds - 1;
|
||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||
}
|
||||
}
|
||||
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
||||
|
|
@ -12144,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
ctx->num_additional_fused_ops = num_adds - 1;
|
||||
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
||||
ctx->num_additional_fused_ops = 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
||||
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
||||
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
||||
ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -12151,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|||
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
||||
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
||||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
||||
(i + ctx->num_additional_fused_ops == last_node) ||
|
||||
(i + ctx->num_additional_fused_ops >= last_node) ||
|
||||
(almost_ready && !ctx->almost_ready_fence_pending);
|
||||
|
||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
|
||||
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
|
||||
|
||||
if (vk_perf_logger_enabled) {
|
||||
if (ctx->compute_ctx.expired()) {
|
||||
|
|
@ -12275,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|||
while (first_unused < graph->n_nodes) {
|
||||
std::vector<int> current_set;
|
||||
|
||||
// Avoid reordering topk_moe_norm
|
||||
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
|
||||
bool is_topk_moe_norm = true;
|
||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
|
||||
is_topk_moe_norm = false;
|
||||
}
|
||||
}
|
||||
if (is_topk_moe_norm) {
|
||||
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
||||
new_order.push_back(graph->nodes[first_unused + j]);
|
||||
used[first_unused + j] = true;
|
||||
}
|
||||
while (first_unused < graph->n_nodes && used[first_unused]) {
|
||||
first_unused++;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// First, grab the next unused node.
|
||||
current_set.push_back(first_unused);
|
||||
|
||||
|
|
@ -12879,6 +13302,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
return true;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
for (int i = 0; i < 6; i++) {
|
||||
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint32_t d_state = op->src[0]->ne[0];
|
||||
const uint32_t head_dim = op->src[0]->ne[1];
|
||||
|
||||
bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
|
||||
if (!is_mamba2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
||||
const vk_device& device = ggml_vk_get_device(ctx->device);
|
||||
|
||||
const uint32_t SPLIT_H = 16;
|
||||
|
||||
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
|
||||
|
||||
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
case GGML_OP_SSM_CONV:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_CONV_2D:
|
||||
|
|
@ -13223,14 +13687,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
|
||||
struct ggml_context * ggml_ctx = ggml_init(iparams);
|
||||
|
||||
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
|
||||
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
|
||||
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
|
||||
std::array<size_t, GGML_MAX_SRC> src_size = {};
|
||||
std::array<void *, GGML_MAX_SRC> src_buffer = {};
|
||||
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
|
||||
|
||||
struct ggml_tensor * tensor_clone = nullptr;
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
ggml_tensor * srci = tensor->src[i];
|
||||
if (fused_rms_norm_mul) {
|
||||
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
|
||||
|
|
@ -13537,6 +14001,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
src_clone[2]);
|
||||
} else if (tensor->op == GGML_OP_ADD_ID) {
|
||||
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
|
||||
} else if (tensor->op == GGML_OP_SSM_SCAN) {
|
||||
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
|
||||
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
|
||||
} else if (tensor->op == GGML_OP_SSM_CONV) {
|
||||
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
|
||||
}
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
|
@ -13558,7 +14027,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
|||
memcpy(comp_result, tensor_clone->data, comp_size);
|
||||
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
|
||||
for (int i = 0; i < 6; i++) {
|
||||
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
||||
if (src_buffer[i] != nullptr) {
|
||||
free(src_buffer[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -345,7 +345,7 @@ void main() {
|
|||
|
||||
float Lfrcp[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
|
|
|||
|
|
@ -380,7 +380,7 @@ void main() {
|
|||
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
|
|
|||
|
|
@ -121,7 +121,11 @@ void main() {
|
|||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
|
||||
#else
|
||||
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
|
||||
#endif
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
|
||||
|
||||
|
|
@ -294,7 +298,7 @@ void main() {
|
|||
|
||||
[[unroll]]
|
||||
for (int k = 0; k < Ldiag.length(); ++k) {
|
||||
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
|
||||
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
|
||||
}
|
||||
|
||||
O = Ldiag*O;
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ void main() {
|
|||
L = L*ms + vs;
|
||||
}
|
||||
|
||||
L = 1.0 / L;
|
||||
L = (L == 0.0) ? 0.0 : 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(binding = 0) readonly buffer Src0 { float src0[]; };
|
||||
layout(binding = 1) readonly buffer Src1 { float src1[]; };
|
||||
layout(binding = 2) buffer Dst { float dst[]; };
|
||||
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint nb01; uint nb02;
|
||||
uint nb11;
|
||||
uint dst_nb0; uint dst_nb1; uint dst_nb2;
|
||||
uint nc; uint ncs; uint nr; uint n_t; uint n_s;
|
||||
};
|
||||
|
||||
void main() {
|
||||
const uint global_thread_id = gl_GlobalInvocationID.x;
|
||||
const uint i2 = gl_WorkGroupID.y;
|
||||
const uint i3 = gl_WorkGroupID.z;
|
||||
|
||||
if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint i1 = global_thread_id;
|
||||
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
|
||||
const uint src1_base = i1 * (nb11 / 4);
|
||||
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
|
||||
|
||||
float sum = 0.0;
|
||||
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
|
||||
const uint src0_idx = src0_base + i0;
|
||||
const uint src1_idx = src1_base + i0;
|
||||
sum += src0[src0_idx] * src1[src1_idx];
|
||||
}
|
||||
|
||||
dst[dst_idx] = sum;
|
||||
}
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout(constant_id = 0) const uint D_STATE = 128;
|
||||
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
|
||||
layout(constant_id = 2) const uint SPLIT_H = 16;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(binding = 0) readonly buffer Src0 { float s0[]; };
|
||||
layout(binding = 1) readonly buffer Src1 { float x[]; };
|
||||
layout(binding = 2) readonly buffer Src2 { float dt[]; };
|
||||
layout(binding = 3) readonly buffer Src3 { float A[]; };
|
||||
layout(binding = 4) readonly buffer Src4 { float B[]; };
|
||||
layout(binding = 5) readonly buffer Src5 { float C[]; };
|
||||
layout(binding = 6) readonly buffer Src6 { int ids[]; };
|
||||
layout(binding = 7) buffer Dst { float d[]; };
|
||||
|
||||
layout(push_constant) uniform PushConstants {
|
||||
uint nb02; uint nb03; uint nb12; uint nb13;
|
||||
uint nb21; uint nb22; uint nb31;
|
||||
uint nb42; uint nb43; uint nb52; uint nb53;
|
||||
uint s_off;
|
||||
uint n_head;
|
||||
uint d_head;
|
||||
uint n_group;
|
||||
uint n_tok;
|
||||
};
|
||||
|
||||
float softplus(float x) {
|
||||
if (x <= 20.0) {
|
||||
return log(1.0 + exp(x));
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
|
||||
shared float stateC[SPLIT_H * D_STATE];
|
||||
|
||||
void main() {
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
|
||||
const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
|
||||
const uint seq_idx = gl_WorkGroupID.y;
|
||||
|
||||
const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;
|
||||
const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
|
||||
const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
|
||||
const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
|
||||
const uint A_base_idx = (head_idx * nb31) / 4;
|
||||
const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
|
||||
const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
|
||||
const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;
|
||||
const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
|
||||
|
||||
const uint stride_x = nb12 / 4;
|
||||
const uint stride_dt = nb21 / 4;
|
||||
const uint stride_B = nb42 / 4;
|
||||
const uint stride_C = nb52 / 4;
|
||||
const uint stride_y = n_head * d_head;
|
||||
|
||||
float state[SPLIT_H];
|
||||
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
||||
state[j] = s0[s0_base_idx + j * D_STATE + tid];
|
||||
}
|
||||
|
||||
for (uint i = 0; i < n_tok; i++) {
|
||||
const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
|
||||
|
||||
const float dA = exp(dt_soft_plus * A[A_base_idx]);
|
||||
|
||||
const float B_val = B[B_base_idx + i * stride_B + tid];
|
||||
const float C_val = C[C_base_idx + i * stride_C + tid];
|
||||
|
||||
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
||||
const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
|
||||
|
||||
state[j] = (state[j] * dA) + (B_val * x_dt);
|
||||
|
||||
stateC[j * D_STATE + tid] = state[j] * C_val;
|
||||
}
|
||||
|
||||
barrier();
|
||||
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
|
||||
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
||||
const uint k = (tid % (w >> 1)) +
|
||||
(D_STATE * (tid / (w >> 1))) +
|
||||
j * D_STATE * (D_STATE / (w >> 1));
|
||||
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
|
||||
stateC[k] += stateC[k + (w >> 1)];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
|
||||
const uint idx = (tid % SUBGROUP_SIZE) +
|
||||
D_STATE * (tid / SUBGROUP_SIZE) +
|
||||
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
||||
|
||||
uint lane = tid % SUBGROUP_SIZE;
|
||||
|
||||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
if (idx + offset < SPLIT_H * D_STATE) {
|
||||
stateC[idx] += stateC[idx + offset];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
|
||||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
||||
d[y_base_idx + i * stride_y + k] = stateC[idx];
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
|
||||
d[s_base_idx + j * D_STATE + tid] = state[j];
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
uint n_rows;
|
||||
uint n_expert_used;
|
||||
};
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint WARP_SIZE = 32;
|
||||
layout(constant_id = 1) const uint n_experts = 512;
|
||||
layout(constant_id = 2) const bool with_norm = true;
|
||||
|
||||
const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
|
||||
|
||||
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
|
||||
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
|
||||
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
|
||||
if (row >= n_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
const uint logits_offset = n_experts * row;
|
||||
const uint weights_offset = n_expert_used * row;
|
||||
const uint ids_offset = n_experts * row;
|
||||
|
||||
float logits_r[experts_per_thread];
|
||||
|
||||
const float INFINITY = 1.0 / 0.0;
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
|
||||
const uint expert = i + gl_LocalInvocationID.x;
|
||||
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
|
||||
}
|
||||
|
||||
float max_val = logits_r[0];
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
max_val = max(val, max_val);
|
||||
}
|
||||
|
||||
max_val = subgroupMax(max_val);
|
||||
|
||||
float wt[experts_per_thread];
|
||||
float tmp = 0.f;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
const float val = logits_r[i];
|
||||
wt[i] = exp(val - max_val);
|
||||
tmp += wt[i];
|
||||
}
|
||||
|
||||
tmp = subgroupAdd(tmp);
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 0; i < experts_per_thread; i++) {
|
||||
wt[i] = wt[i] * inv_sum;
|
||||
}
|
||||
|
||||
// at this point, each thread holds a portion of softmax,
|
||||
// we do the argmax reduce over n_expert_used, each time marking
|
||||
// the expert weight as -inf to exclude from the next iteration
|
||||
|
||||
float wt_sum = 0.f;
|
||||
|
||||
float output_weights[experts_per_thread];
|
||||
|
||||
for (int k = 0; k < n_expert_used; k++) {
|
||||
float max_val = wt[0];
|
||||
uint max_expert = gl_LocalInvocationID.x;
|
||||
|
||||
[[unroll]]
|
||||
for (int i = 1; i < experts_per_thread; i++) {
|
||||
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
|
||||
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
|
||||
max_val = wt[i];
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
|
||||
const float val = subgroupShuffleXor(max_val, mask);
|
||||
const uint expert = subgroupShuffleXor(max_expert, mask);
|
||||
if (val > max_val || (val == max_val && expert < max_expert)) {
|
||||
max_val = val;
|
||||
max_expert = expert;
|
||||
}
|
||||
}
|
||||
|
||||
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
|
||||
output_weights[k / WARP_SIZE] = max_val;
|
||||
}
|
||||
|
||||
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
|
||||
wt[max_expert / WARP_SIZE] = -INFINITY;
|
||||
|
||||
ids[ids_offset + k] = max_expert;
|
||||
if (with_norm) {
|
||||
wt_sum += max_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (with_norm) {
|
||||
wt_sum = subgroupAdd(wt_sum);
|
||||
const float inv_sum = 1.0f / wt_sum;
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||
output_weights[i] *= inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]]
|
||||
for (uint i = 0; i < experts_per_thread; ++i) {
|
||||
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
|
||||
if (idx < n_expert_used) {
|
||||
weights[weights_offset + idx] = output_weights[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -916,6 +916,12 @@ void process_shaders() {
|
|||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
string_to_spv("topk_moe_f32", "topk_moe.comp", {});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
}
|
||||
|
|
@ -959,7 +965,7 @@ void write_output_files() {
|
|||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
|
||||
hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
|
||||
|
||||
|
|
|
|||
|
|
@ -6964,6 +6964,78 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
|
|||
GGML_LOG_INFO("========================================\n");
|
||||
}
|
||||
|
||||
static int ggml_node_list_find_tensor(const struct ggml_cgraph * cgraph,
|
||||
const int * idxs,
|
||||
int count,
|
||||
const struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(cgraph && idxs);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const int node_idx = idxs[i];
|
||||
|
||||
if (node_idx >= cgraph->n_nodes) {
|
||||
return -1;
|
||||
}
|
||||
if (cgraph->nodes[node_idx] == tensor) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
|
||||
const int * node_idxs,
|
||||
int count,
|
||||
const enum ggml_op * ops,
|
||||
const int * outputs,
|
||||
int num_outputs) {
|
||||
GGML_ASSERT(outputs && num_outputs > 0);
|
||||
|
||||
for (int i = 0; i < count; ++i) {
|
||||
if (node_idxs[i] >= cgraph->n_nodes) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const struct ggml_tensor * node = cgraph->nodes[node_idxs[i]];
|
||||
|
||||
if (node->op != ops[i]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int subgraph_uses = 0;
|
||||
for (int j = i + 1; j < count; ++j) {
|
||||
const struct ggml_tensor * other_node = cgraph->nodes[node_idxs[j]];
|
||||
for (int src_idx = 0; src_idx < GGML_MAX_SRC; src_idx++) {
|
||||
if (other_node->src[src_idx] == node) {
|
||||
subgraph_uses++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (subgraph_uses != ggml_node_get_use_count(cgraph, node_idxs[i])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if node is a view, check if the view_src and all it's parent view_srcs are within the subgraph
|
||||
struct ggml_tensor * view_src = node->view_src;
|
||||
while (view_src) {
|
||||
if (ggml_node_list_find_tensor(cgraph, node_idxs, count, view_src) == -1) {
|
||||
return false;
|
||||
}
|
||||
view_src = view_src->view_src;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// check if node is part of the graph
|
||||
static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) {
|
||||
if (cgraph == NULL) {
|
||||
|
|
|
|||
|
|
@ -102,6 +102,8 @@ class Keys:
|
|||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||
EXPERT_GROUP_COUNT = "{arch}.expert_group_count"
|
||||
EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count"
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||
|
|
@ -400,6 +402,7 @@ class MODEL_ARCH(IntEnum):
|
|||
WAVTOKENIZER_DEC = auto()
|
||||
PLM = auto()
|
||||
BAILINGMOE = auto()
|
||||
BAILINGMOE2 = auto()
|
||||
DOTS1 = auto()
|
||||
ARCEE = auto()
|
||||
ERNIE4_5 = auto()
|
||||
|
|
@ -744,6 +747,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
MODEL_ARCH.PLM: "plm",
|
||||
MODEL_ARCH.BAILINGMOE: "bailingmoe",
|
||||
MODEL_ARCH.BAILINGMOE2: "bailingmoe2",
|
||||
MODEL_ARCH.DOTS1: "dots1",
|
||||
MODEL_ARCH.ARCEE: "arcee",
|
||||
MODEL_ARCH.ERNIE4_5: "ernie4_5",
|
||||
|
|
@ -2533,6 +2537,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.BAILINGMOE2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.NEXTN_EH_PROJ,
|
||||
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
|
||||
MODEL_TENSOR.NEXTN_ENORM,
|
||||
MODEL_TENSOR.NEXTN_HNORM,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
|
||||
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
],
|
||||
MODEL_ARCH.DOTS1: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -755,6 +755,12 @@ class GGUFWriter:
|
|||
def add_expert_shared_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_group_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_group_used_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_weights_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ class TensorNameMap:
|
|||
"h.{bid}.self_attention.query_key_value", # bloom
|
||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_key_value", # persimmon
|
||||
"model.layers.{bid}.attention.query_key_value", # bailingmoe2
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
|
|
@ -260,6 +261,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||
"model.layers.{bid}.self_attn.dense", # persimmon
|
||||
"model.layers.{bid}.attention.dense", # bailingmoe2
|
||||
"h.{bid}.attn.c_proj", # gpt2
|
||||
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||
|
|
@ -373,6 +375,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
|
||||
"model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe
|
||||
"model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2
|
||||
"model.layers.{bid}.feed_forward.expert_bias", # lfm2moe
|
||||
),
|
||||
|
||||
|
|
@ -549,6 +552,7 @@ class TensorNameMap:
|
|||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_layernorm", # hunyuan
|
||||
"model.layers.{bid}.attention.query_layernorm", # bailingmoe2
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
|
||||
"layers.{bid}.self_attn.q_norm", # embeddinggemma
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
|
|
@ -563,6 +567,7 @@ class TensorNameMap:
|
|||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.key_layernorm", # hunyuan
|
||||
"model.layers.{bid}.attention.key_layernorm", # bailingmoe2
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
|
||||
"layers.{bid}.self_attn.k_norm", # embeddinggemma
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
|
|
@ -584,6 +589,7 @@ class TensorNameMap:
|
|||
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
||||
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
|
||||
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
|
||||
"model.layers.{bid}.final_layernorm", # bailingmoe2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
|
||||
|
|
|
|||
|
|
@ -14,12 +14,12 @@ except ImportError:
|
|||
SentencePieceProcessor = None
|
||||
|
||||
try:
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
from mistral_common.tokens.tokenizers.utils import (
|
||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports]
|
||||
from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports]
|
||||
_filter_valid_tokenizer_files,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports]
|
||||
SentencePieceTokenizer,
|
||||
)
|
||||
except ImportError:
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue