diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe86863893..15e1133095 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1305,6 +1305,81 @@ jobs: cd examples/llama.android ./gradlew build --no-daemon + android-ndk-build: + runs-on: ubuntu-latest + + env: + OPENCL_VERSION: 2025.07.22 + + strategy: + matrix: + include: + - build: 'arm64-cpu' + defines: '-D ANDROID_ABI=arm64-v8a -D ANDROID_PLATFORM=android-31 -D CMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -D GGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm -G Ninja -D LLAMA_CURL=OFF -D GGML_OPENMP=OFF' + - build: 'arm64-snapdragon' + defines: '--preset arm64-android-snapdragon-release' + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.build == 'arm64-snapdragon' }} + run: | + mkdir opencl + curl -L -o opencl/clhpp.tar.gz https://github.com/KhronosGroup/OpenCL-CLHPP/archive/refs/tags/v${OPENCL_VERSION}.tar.gz + curl -L -o opencl/headers.tar.gz https://github.com/KhronosGroup/OpenCL-Headers/archive/refs/tags/v${OPENCL_VERSION}.tar.gz + curl -L -o opencl/icd-loader.tar.gz https://github.com/KhronosGroup/OpenCL-ICD-Loader/archive/refs/tags/v${OPENCL_VERSION}.tar.gz + tar -xaf opencl/headers.tar.gz -C opencl + tar -xaf opencl/clhpp.tar.gz -C opencl + tar -xaf opencl/icd-loader.tar.gz -C opencl + sudo cp -r opencl/OpenCL-Headers-${OPENCL_VERSION}/CL ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + sudo cp -r opencl/OpenCL-CLHPP-${OPENCL_VERSION}/include/CL/* ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include/CL + cd opencl/OpenCL-ICD-Loader-${OPENCL_VERSION} + cmake -B build -G Ninja -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake -DOPENCL_ICD_LOADER_HEADERS_DIR=${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=31 -DANDROID_STL=c++_shared + cmake --build build + sudo cp build/libOpenCL.so ${ANDROID_NDK_ROOT}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android + rm -rf opencl + + - name: Install Hexagon SDK + id: install_hexsdk + if: ${{ matrix.build == 'arm64-snapdragon' }} + env: + HEXSDK_VER: 6.4.0.2 + HEXTLS_VER: 19.0.04 + run: | + curl -L -o hex-sdk.tar.gz https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v$HEXSDK_VER/hexagon-sdk-v$HEXSDK_VER-amd64-lnx.tar.xz + mkdir hex-sdk + tar -xaf hex-sdk.tar.gz -C hex-sdk + ls -l hex-sdk + sudo mv hex-sdk /opt/hexagon + echo "HEXAGON_SDK_ROOT=/opt/hexagon/$HEXSDK_VER" >> "$GITHUB_ENV" + echo "HEXAGON_TOOLS_ROOT=/opt/hexagon/$HEXSDK_VER/tools/HEXAGON_Tools/$HEXTLS_VER" >> "$GITHUB_ENV" + echo "DEFAULT_HLOS_ARCH=64" >> "$GITHUB_ENV" + echo "DEFAULT_TOOLS_VARIANT=toolv19" >> "$GITHUB_ENV" + echo "DEFAULT_NO_QURT_INC=0" >> "$GITHUB_ENV" + echo "DEFAULT_DSP_ARCH=v73" >> "$GITHUB_ENV" + + - name: Update CMake presets + id: update_presets + if: ${{ matrix.build == 'arm64-snapdragon' }} + run: | + cp docs/backend/hexagon/CMakeUserPresets.json . + + - name: Build + id: ndk_build + run: | + cmake ${{ matrix.defines }} -B build + cmake --build build + cmake --install build --prefix pkg-adb/llama.cpp + + - name: Test + id: cmake_test + run: | + echo "FIXME: test on devices" + openEuler-latest-cmake-cann: if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: diff --git a/CODEOWNERS b/CODEOWNERS index f833fb7cf4..bacc86cbbd 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -65,6 +65,7 @@ /ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-metal/ @ggerganov /ggml/src/ggml-opencl/ @lhez @max-krasnyansky +/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov diff --git a/README.md b/README.md index e373611051..f4206e8d45 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral) - [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct) +- [x] [Jamba](https://huggingface.co/ai21labs) - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon) - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2) - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne) @@ -280,6 +281,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE | | [WebGPU [In Progress]](docs/build.md#webgpu) | All | | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | +| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon | ## Obtaining and quantizing models diff --git a/common/arg.cpp b/common/arg.cpp index 33ed7ae857..a465eb3623 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", - "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix", + "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", [](common_params & params, const std::string & value) { params.embd_out = value; } @@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.use_jinja = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" diff --git a/common/chat.cpp b/common/chat.cpp index 8587140e1f..63583fb224 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -9,8 +9,11 @@ #include #include +#include #include +#include #include +#include #include #include #include @@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; default: throw std::runtime_error("Unknown chat format"); } @@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } + +// Case-insensitive find +static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { + auto it = std::search( + haystack.begin() + pos, haystack.end(), + needle.begin(), needle.end(), + [](char a, char b) { return std::tolower(a) == std::tolower(b); } + ); + return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); +} + +static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + const auto is_json_schema_provided = !inputs.json_schema.is_null(); + const auto is_grammar_provided = !inputs.grammar.empty(); + const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); + + // the logic requires potentially modifying the messages + auto tweaked_messages = inputs.messages; + + auto replace_json_schema_marker = [](json & messages) -> bool { + static std::string marker1 = "force json schema.\n"; + static std::string marker2 = "force json schema."; + + if (messages.empty() || messages.at(0).at("role") != "system") { + return false; + } + + std::string content = messages.at(0).at("content"); + + for (const auto & marker : {marker1, marker2}) { + const auto pos = ifind_string(content, marker); + if (pos != std::string::npos) { + content.replace(pos, marker.length(), ""); + // inject modified content back into the messages + messages.at(0).at("content") = content; + return true; + } + } + + return false; + }; + + // Lfm2 model does not natively work with json, but can generally understand the tools structure + // + // Example of the pytorch dialog structure: + // <|startoftext|><|im_start|>system + // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> + // <|im_start|>user + // What is the current status of candidate ID 12345?<|im_end|> + // <|im_start|>assistant + // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> + // <|im_start|>tool + // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> + // <|im_start|>assistant + // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> + // + // For the llama server compatibility with json tools semantic, + // the client can add "Follow json schema." line into the system message prompt to force the json output. + // + if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { + // server/utils.hpp prohibits that branch for the custom grammar anyways + throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); + } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { + LOG_INF("%s: Using tools to build a grammar\n", __func__); + + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + auto schemas = json::array(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + schemas.push_back({ + {"type", "object"}, + {"properties", { + {"name", { + {"type", "string"}, + {"const", function.at("name")}, + }}, + {"arguments", function.at("parameters")}, + }}, + {"required", json::array({"name", "arguments", "id"})}, + }); + }); + auto schema = json { + {"type", "array"}, + {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, + {"minItems", 1}, + }; + if (!inputs.parallel_tool_calls) { + schema["maxItems"] = 1; + } + + builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); + }); + // model has no concept of tool selection mode choice, + // if the system prompt rendered correctly it will produce a tool call + // the grammar goes inside the tool call body + data.grammar_lazy = true; + data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; + } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { + LOG_INF("%s: Using tools without json schema or grammar\n", __func__); + // output those tokens + data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; + } else if (is_json_schema_provided) { + LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); + data.grammar = json_schema_to_grammar(inputs.json_schema); + } else if (is_grammar_provided) { + LOG_INF("%s: Using provided grammar\n", __func__); + data.grammar = inputs.grammar; + } else { + LOG_INF("%s: Using content relying on the template\n", __func__); + } + + data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); + LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); + + return data; +} + static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { // Parse thinking tags first - this handles the main reasoning content builder.try_parse_reasoning("", ""); @@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_apertus(tmpl, params); } + // LFM2 (w/ tools) + if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && + src.find("]<|tool_list_end|>") != std::string::npos) { + return common_chat_params_init_lfm2(tmpl, params); + } + // Use generic handler when mixing tools + JSON schema. // TODO: support that mix in handlers below. if ((params.tools.is_array() && params.json_schema.is_object())) { @@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_APERTUS: common_chat_parse_apertus(builder); break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.h b/common/chat.h index f7b36ec711..50efb0d4e5 100644 --- a/common/chat.h +++ b/common/chat.h @@ -116,6 +116,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_NEMOTRON_V2, COMMON_CHAT_FORMAT_APERTUS, + COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index dd9b51a9e5..478aa1be7b 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -601,7 +601,10 @@ private: } std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); + auto it = ref.find('#'); + std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref; + static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)"); + std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-"); if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { _refs_being_resolved.insert(ref); json resolved = _refs[ref]; @@ -774,11 +777,24 @@ public: std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { + if (target.is_object() && target.contains(sel)) { + target = target[sel]; + } else if (target.is_array()) { + size_t sel_index; + try { + sel_index = std::stoul(sel); + } catch (const std::invalid_argument & e) { + sel_index = target.size(); + } + if (sel_index >= target.size()) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return; + } + target = target[sel_index]; + } else { _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); return; } - target = target[sel]; } _refs[ref] = target; } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 5cfb4f29df..8a1d37c5ad 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -29,12 +29,29 @@ if 'NO_LOCAL_GGUF' not in os.environ: sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) import gguf from gguf.vocab import MistralTokenizerType, MistralVocab -from mistral_common.tokens.tokenizers.base import TokenizerVersion -from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN, DATASET_STD -from mistral_common.tokens.tokenizers.tekken import Tekkenizer -from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer, -) + +try: + from mistral_common.tokens.tokenizers.base import TokenizerVersion # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.multimodal import DATASET_MEAN as _MISTRAL_COMMON_DATASET_MEAN, DATASET_STD as _MISTRAL_COMMON_DATASET_STD # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] + SentencePieceTokenizer, + ) + + _mistral_common_installed = True + _mistral_import_error_msg = "" +except ImportError: + _MISTRAL_COMMON_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) + _MISTRAL_COMMON_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) + + _mistral_common_installed = False + TokenizerVersion = None + Tekkenizer = None + SentencePieceTokenizer = None + _mistral_import_error_msg = ( + "Mistral format requires `mistral-common` to be installed. Please run " + "`pip install mistral-common[image,audio]` to install it." + ) logger = logging.getLogger("hf-to-gguf") @@ -73,10 +90,8 @@ class ModelBase: use_temp_file: bool lazy: bool dry_run: bool - part_names: list[str] - is_safetensors: bool hparams: dict[str, Any] - tensor_names: set[str] | None + model_tensors: dict[str, Callable[[], Tensor]] gguf_writer: gguf.GGUFWriter model_name: str | None metadata_override: Path | None @@ -107,6 +122,9 @@ class ModelBase: type(self) is MmprojModel: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") + if self.is_mistral_format and not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) + self.dir_model = dir_model self.ftype = ftype self.fname_out = fname_out @@ -117,25 +135,8 @@ class ModelBase: self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules - if remote_hf_model_id is not None: - self.is_safetensors = True - - def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: - logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") - remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) - self.tensor_names = set(name for name in remote_tensors.keys()) - for name, remote_tensor in remote_tensors.items(): - yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) - - self.get_tensors = get_remote_tensors - else: - prefix = "model" if not self.is_mistral_format else "consolidated" - self.part_names = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors") - self.is_safetensors = len(self.part_names) > 0 - if not self.is_safetensors: - self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams - self.tensor_names = None + self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override self.model_name = model_name self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py @@ -151,6 +152,8 @@ class ModelBase: logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") self.ftype = gguf.LlamaFileType.MOSTLY_BF16 + self.dequant_model() + # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -172,67 +175,215 @@ class ModelBase: return None raise KeyError(f"could not find any of: {keys}") - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: - tensor_names_from_parts: set[str] = set() + def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: + tensors: dict[str, Callable[[], Tensor]] = {} + + if remote_hf_model_id is not None: + is_safetensors = True + + logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") + remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + for name, remote_tensor in remote_tensors.items(): + tensors[name] = lambda r=remote_tensor: LazyTorchTensor.from_remote_tensor(r) + + return tensors + + prefix = "model" if not self.is_mistral_format else "consolidated" + part_names: list[str] = ModelBase.get_model_part_names(self.dir_model, prefix, ".safetensors") + is_safetensors: bool = len(part_names) > 0 + if not is_safetensors: + part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + + tensor_names_from_index: set[str] = set() if not self.is_mistral_format: - index_name = "model.safetensors" if self.is_safetensors else "pytorch_model.bin" + index_name = "model.safetensors" if is_safetensors else "pytorch_model.bin" index_name += ".index.json" index_file = self.dir_model / index_name if index_file.is_file(): - self.tensor_names = set() logger.info(f"gguf: loading model weight map from '{index_name}'") with open(index_file, "r", encoding="utf-8") as f: index: dict[str, Any] = json.load(f) weight_map = index.get("weight_map") if weight_map is None or not isinstance(weight_map, dict): raise ValueError(f"Can't load 'weight_map' from {index_name!r}") - self.tensor_names.update(weight_map.keys()) + tensor_names_from_index.update(weight_map.keys()) else: - self.tensor_names = tensor_names_from_parts weight_map = {} else: - self.tensor_names = tensor_names_from_parts weight_map = {} - for part_name in self.part_names: - logger.info(f"gguf: loading model part '{part_name}'") + for part_name in part_names: + logger.info(f"gguf: indexing model part '{part_name}'") ctx: ContextManager[Any] - if self.is_safetensors: + if is_safetensors: from safetensors import safe_open ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) else: ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) with ctx as model_part: - tensor_names_from_parts.update(model_part.keys()) + assert model_part is not None for name in model_part.keys(): - if self.is_safetensors: + if is_safetensors: if self.lazy: data = model_part.get_slice(name) - data = LazyTorchTensor.from_safetensors_slice(data) + data_gen = lambda data=data: LazyTorchTensor.from_safetensors_slice(data) # noqa: E731 else: data = model_part.get_tensor(name) + data_gen = lambda data=data: data # noqa: E731 else: data = model_part[name] if self.lazy: - data = LazyTorchTensor.from_eager(data) - yield name, data + data_gen = lambda data=data: LazyTorchTensor.from_eager(data) # noqa: E731 + else: + data_gen = lambda data=data: data # noqa: E731 + tensors[name] = data_gen # verify tensor name presence and identify potentially missing files - if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: - missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) - extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) - missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) - if len(extra) == 0 and len(missing_files) > 0: - raise ValueError(f"Missing or incomplete model files: {missing_files}\n" - f"Missing tensors: {missing}") + if len(tensor_names_from_index) > 0: + tensor_names_from_parts = set(tensors.keys()) + if len(tensor_names_from_parts.symmetric_difference(tensor_names_from_index)) > 0: + missing = sorted(tensor_names_from_index.difference(tensor_names_from_parts)) + extra = sorted(tensor_names_from_parts.difference(tensor_names_from_index)) + missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) + if len(extra) == 0 and len(missing_files) > 0: + raise ValueError(f"Missing or incomplete model files: {missing_files}\n" + f"Missing tensors: {missing}") + else: + raise ValueError("Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}") + + return tensors + + def dequant_model(self): + tensors_to_remove: list[str] = [] + new_tensors: dict[str, Callable[[], Tensor]] = {} + + if (quant_config := self.hparams.get("quantization_config")) and isinstance(quant_config, dict): + quant_method = quant_config.get("quant_method") + + def dequant_bitnet(weight: Tensor, scale: Tensor) -> Tensor: + weight = weight.view(torch.uint8) + orig_shape = weight.shape + + shift = torch.tensor([0, 2, 4, 6], dtype=torch.uint8).reshape((4, *(1 for _ in range(len(orig_shape))))) + data = weight.unsqueeze(0).expand((4, *orig_shape)) >> shift + data = data & 3 + data = (data.float() - 1).reshape((orig_shape[0] * 4, *orig_shape[1:])) + + # The scale is inverted + return data / scale.float() + + def dequant_simple(weight: Tensor, scale: Tensor) -> Tensor: + scale = scale.float() + + if (weight_block_size := quant_config.get("weight_block_size")): + # TODO: make sure it's a list of integers + for i, size in enumerate(weight_block_size): + scale = scale.repeat_interleave(size, i) + # unpad the scale (e.g. when the tensor size isn't a multiple of the block size) + scale = scale[tuple(slice(0, size) for size in weight.shape)] + + return weight.float() * scale + + # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476 + def dequant_gptq(g_idx: Tensor, qweight: Tensor, qzeros: Tensor, scales: Tensor) -> Tensor: + bits = quant_config["bits"] + assert bits in (2, 3, 4, 8) + assert qweight.dtype == qzeros.dtype + maxq = (2 ** bits) - 1 + weight = None + zeros = None + pack_dtype_bits = qweight.dtype.itemsize * 8 + + if bits in [2, 4, 8]: + pack_factor = pack_dtype_bits // bits + wf = torch.tensor(list(range(0, pack_dtype_bits, bits)), dtype=torch.int32).unsqueeze(0) + if self.lazy: + wf = LazyTorchTensor.from_eager(wf) + + zeros = torch.bitwise_right_shift( + qzeros.unsqueeze(2).expand(-1, -1, pack_factor), + wf.unsqueeze(0) + ).to(torch.int16 if bits == 8 else torch.int8) + zeros = torch.bitwise_and(zeros, maxq).reshape(scales.shape) + + weight = torch.bitwise_and( + torch.bitwise_right_shift( + qweight.unsqueeze(1).expand(-1, pack_factor, -1), + wf.unsqueeze(-1) + ).to(torch.int16 if bits == 8 else torch.int8), + maxq + ) + elif bits == 3: + raise NotImplementedError("3-bit gptq dequantization is not yet implemented") + + assert weight is not None + assert zeros is not None + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + # gptq_v2 doesn't need to offset zeros + if quant_config.get("checkpoint_format", "gptq") == "gptq": + zeros += 1 + + return (scales[g_idx].float() * (weight - zeros[g_idx]).float()).T + + if quant_method == "bitnet": + for name in self.model_tensors.keys(): + if name.endswith(".weight_scale"): + weight_name = name.removesuffix("_scale") + w = self.model_tensors[weight_name] + s = self.model_tensors[name] + self.model_tensors[weight_name] = lambda w=w, s=s: dequant_bitnet(w(), s()) + tensors_to_remove.append(name) + elif quant_method == "fp8": + for name in self.model_tensors.keys(): + if name.endswith(".weight_scale_inv"): + weight_name = name.removesuffix("_scale_inv") + w = self.model_tensors[weight_name] + s = self.model_tensors[name] + self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s()) + tensors_to_remove.append(name) + elif quant_method == "gptq": + for name in self.model_tensors.keys(): + if name.endswith(".qweight"): + base_name = name.removesuffix(".qweight") + g_idx = self.model_tensors[base_name + ".g_idx"] + qweight = self.model_tensors[base_name + ".qweight"] + qzeros = self.model_tensors[base_name + ".qzeros"] + scales = self.model_tensors[base_name + ".scales"] + new_tensors[base_name + ".weight"] = ( + lambda g=g_idx, z=qzeros, w=qweight, s=scales: dequant_gptq( + g(), w(), z(), s() + ) + ) + tensors_to_remove += [ + base_name + n + for n in ( + ".g_idx", + ".qzeros", + ".qweight", + ".scales", + ) + ] else: - raise ValueError("Mismatch between weight map and model parts for tensor names:\n" - f"Missing tensors: {missing}\n" - f"Extra tensors: {extra}") + raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}") + + for name in tensors_to_remove: + if name in self.model_tensors: + del self.model_tensors[name] + + for name, value in new_tensors.items(): + self.model_tensors[name] = value + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, gen in self.model_tensors.items(): + yield name, gen() def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: if key not in gguf.MODEL_TENSORS[self.model_arch]: @@ -591,6 +742,12 @@ class TextModel(ModelBase): if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) logger.info(f"gguf: experts used count = {n_experts_used}") + if (n_expert_groups := self.hparams.get("n_group")) is not None: + self.gguf_writer.add_expert_group_count(n_expert_groups) + logger.info(f"gguf: expert groups count = {n_expert_groups}") + if (n_group_used := self.hparams.get("topk_group")) is not None: + self.gguf_writer.add_expert_group_used_count(n_group_used) + logger.info(f"gguf: expert groups used count = {n_group_used}") if (head_dim := self.hparams.get("head_dim")) is not None: self.gguf_writer.add_key_length(head_dim) @@ -1349,6 +1506,17 @@ class MmprojModel(ModelBase): def set_type(self): self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) + def prepare_metadata(self, vocab_only: bool): + super().prepare_metadata(vocab_only=vocab_only) + + output_type: str = self.ftype.name.partition("_")[2] + + if self.fname_out.is_dir(): + fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None) + self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf" + else: + self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) + def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) @@ -1366,8 +1534,8 @@ class MmprojModel(ModelBase): self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) # preprocessor config - image_mean = DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] - image_std = DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] + image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] + image_std = _MISTRAL_COMMON_DATASET_STD if self.is_mistral_format else self.preprocessor_config["image_std"] self.gguf_writer.add_vision_image_mean(image_mean) self.gguf_writer.add_vision_image_std(image_std) @@ -2036,6 +2204,9 @@ class LlamaModel(TextModel): self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) def _set_vocab_mistral(self): + if not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) + vocab = MistralVocab(self.dir_model) logger.info( f"Converting tokenizer {vocab.tokenizer_type} of size {vocab.vocab_size}." @@ -2292,18 +2463,21 @@ class ArceeModel(LlamaModel): ) class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 + use_break_tok = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams.get("model_type") == "pixtral": # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) - self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + if self.use_break_tok: + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") elif self.is_mistral_format: # hparams is already vision config here so norm_eps is only defined in global_config. self.hparams["norm_eps"] = self.global_config.get("norm_eps", None) assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json" - self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) + if self.use_break_tok: + self.img_break_tok_id = self.find_vparam(["image_break_token_id"]) else: raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") logger.info(f"Image break token id: {self.img_break_tok_id}") @@ -3794,6 +3968,10 @@ class Qwen3Model(Qwen2Model): return torch.stack([true_row, false_row], dim=0) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "model.vision_" in name: + # skip multimodal tensors + return [] + if self.is_rerank: is_tied_head = self.is_tied_embeddings and "embed_tokens" in name is_real_head = not self.is_tied_embeddings and "lm_head" in name @@ -4361,27 +4539,6 @@ class CodeShellModel(TextModel): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(1.0) - _has_tok_embd = False - - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - - output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) - tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) - - new_name = self.map_tensor_name(name) - - # assuming token_embd.weight is seen before output.weight - if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): - # even though the tensor file(s) does not contain the word embeddings they are still in the weight map - if self.tensor_names and "transformer.wte.weight" in self.tensor_names: - logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied") - self.tensor_names.remove("transformer.wte.weight") - elif new_name == tok_embd_name: - self._has_tok_embd = True - - return [(new_name, data_torch)] - @ModelBase.register("InternLM2ForCausalLM") class InternLM2Model(TextModel): @@ -8092,8 +8249,6 @@ class BailingMoeV2Model(TextModel): self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) - self.gguf_writer.add_expert_group_count(hparams["n_group"]) - self.gguf_writer.add_expert_group_used_count(hparams["topk_group"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) if hparams["score_function"] == "sigmoid": @@ -8813,6 +8968,13 @@ class SmolLM3Model(LlamaModel): class GptOssModel(TextModel): model_arch = gguf.MODEL_ARCH.GPT_OSS + # TODO: remove once MXFP4 is supported more generally + def dequant_model(self): + quant_config = self.hparams.get("quantization_config") + if quant_config is not None and quant_config.get("quant_method") == "mxfp4": + return + return super().dequant_model() + def transform_nibble_layout(self, tensor): assert tensor.dtype == torch.uint8 assert tensor.shape[-1] == 16 @@ -9255,7 +9417,7 @@ class MistralModel(LlamaModel): @staticmethod def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool): - assert TokenizerVersion is not None, "mistral_common is not installed" + assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg assert isinstance(vocab.tokenizer, (Tekkenizer, SentencePieceTokenizer)), ( f"Expected Tekkenizer or SentencePieceTokenizer, got {type(vocab.tokenizer)}" ) @@ -9323,6 +9485,21 @@ class PixtralModel(LlavaVisionModel): return super().map_tensor_name(name, try_suffixes) +@ModelBase.register("LightOnOCRForConditionalGeneration") +class LightOnOCRVisionModel(LlavaVisionModel): + is_mistral_format = False + use_break_tok = False + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + name = name.replace("model.vision_encoder.", "vision_tower.") + name = name.replace("model.vision_projection.", "multi_modal_projector.") + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("KimiVLForConditionalGeneration") class KimiVLModel(MmprojModel): def __init__(self, *args, **kwargs): @@ -9632,11 +9809,9 @@ def main() -> None: logger.info(f"Loading model: {dir_model.name}") - if args.mmproj: - if "mmproj" not in fname_out.name: - fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") - is_mistral_format = args.mistral_format + if is_mistral_format and not _mistral_common_installed: + raise ImportError(_mistral_import_error_msg) disable_mistral_community_chat_template = args.disable_mistral_community_chat_template with torch.inference_mode(): diff --git a/docs/backend/hexagon/CMakeUserPresets.json b/docs/backend/hexagon/CMakeUserPresets.json new file mode 100644 index 0000000000..e0b19db0f5 --- /dev/null +++ b/docs/backend/hexagon/CMakeUserPresets.json @@ -0,0 +1,49 @@ +{ + "version": 4, + "configurePresets": [ + { + "name": "arm64-android-snapdragon", + "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x86_64", "strategy": "external" }, + "cacheVariables": { + "ANDROID_ABI": "arm64-v8a", + "ANDROID_PLATFORM": "android-31", + "CMAKE_TOOLCHAIN_FILE": "$env{ANDROID_NDK_ROOT}/build/cmake/android.toolchain.cmake", + "CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE", + "CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only -flto -D_GNU_SOURCE", + "CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG", + "CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG", + "CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", + "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", + "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "PREBUILT_LIB_DIR": "android_aarch64", + "GGML_OPENMP": "OFF", + "GGML_LLAMAFILE": "OFF", + "GGML_OPENCL": "ON", + "GGML_HEXAGON": "ON", + "LLAMA_CURL": "OFF" + } + }, + + { + "name": "arm64-windows-snapdragon", + "inherits": [ "base", "arm64-windows-llvm" ], + "cacheVariables": { + "HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", + "PREBUILT_LIB_DIR": "windows_aarch64", + "GGML_OPENMP": "OFF", + "GGML_LLAMAFILE": "OFF", + "GGML_OPENCL": "ON", + "GGML_HEXAGON": "ON", + "LLAMA_CURL": "OFF" + } + }, + + { "name": "arm64-android-snapdragon-debug" , "inherits": [ "base", "arm64-android-snapdragon", "debug" ] }, + { "name": "arm64-android-snapdragon-release", "inherits": [ "base", "arm64-android-snapdragon", "release" ] }, + + { "name": "arm64-windows-snapdragon-debug" , "inherits": [ "base", "arm64-windows-snapdragon", "debug" ] }, + { "name": "arm64-windows-snapdragon-release", "inherits": [ "base", "arm64-windows-snapdragon", "release" ] } + ] +} diff --git a/docs/backend/hexagon/README.md b/docs/backend/hexagon/README.md new file mode 100644 index 0000000000..85f136ef9e --- /dev/null +++ b/docs/backend/hexagon/README.md @@ -0,0 +1,239 @@ +# Snapdragon-based Android devices + +## How to Build + +The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain). +This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc. + +This method works on Linux, macOS, and Windows. macOS and Windows users should install Docker Desktop. + +``` +~/src/llama.cpp$ docker run -it -u $(id -u):$(id -g) --volume $(pwd):/workspace --platform linux/amd64 ghcr.io/snapdragon-toolchain/arm64-android:v0.3 +[d]/> cd /workspace +``` + +The rest of the Android build process assumes that you're running inside the toolchain container. +Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: + +``` +[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json . + +[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon +Preset CMake variables: + ANDROID_ABI="arm64-v8a" + ... + CMAKE_TOOLCHAIN_FILE="/opt/android-ndk-r28b/build/cmake/android.toolchain.cmake" + GGML_HEXAGON="ON" + GGML_OPENCL="ON" + GGML_OPENMP="OFF" + HEXAGON_SDK_ROOT="/opt/hexagon/6.4.0.2" +... +-- Including OpenCL backend +-- Including Hexagon backend +... +-- Build files have been written to: /workspace/build-snapdragon + +[d]/workspace> cmake --build build-snapdragon +... +[144/356] Performing build step for 'htp-v73' +[1/16] Generating htp_iface_skel.c, htp_iface_stub.c, htp_iface.h +[2/16] Building C object CMakeFiles/ggml-htp-v73.dir/hvx-sigmoid.c.obj +[3/16] Building C object CMakeFiles/ggml-htp-v73.dir/htp-dma.c.obj +[4/16] Building C object CMakeFiles/ggml-htp-v73.dir/worker-pool.c.obj +... +-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v73.so +-- Installing: /workspace/build-snapdragon/ggml/src/ggml-hexagon/libggml-htp-v75.so +... +``` + +To generate an installable "package" simply use cmake --install: + +``` +[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp +-- Install configuration: "Release" +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so +-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so +... +-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench +-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli +... +``` + +## How to Install + +For this step, your device needs to be configured for on-device development. +Please see https://developer.android.com/studio/debug/dev-options for details. + +Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device. +**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.** + +``` +~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/ +pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s) +pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s) +pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s) +102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s) +``` + +At this point, you should also install some models: + +``` +~/src/llama.cpp$ wget https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q4_0.gguf +... +2025-10-11 12:04:52 (10.7 MB/s) - ‘Llama-3.2-1B-Instruct-Q4_0.gguf’ saved [773025920/773025920] + +~/src/llama.cpp$ adb push Llama-3.2-1B-Instruct-Q4_0.gguf /data/local/tmp/gguf +Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s) +``` + +## How to Run + +The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables. + +llama.cpp supports three backends on Snapdragon-based devices: CPU, Adreno GPU (GPUOpenCL), and Hexagon NPU (HTP0-4). +You can select which backend to run the model on using the `D=` variable, which maps to the `--device` option. + +Hexagon NPU behaves as a "GPU" device when it comes to `-ngl` and other offload-related options. + +Here are some examples of running various llama.cpp tools via ADB. + +Simple question for Llama-3.2-1B + +``` +~/src/llama.cpp$ M=Llama-3.2-1B-Instruct-Q4_0.gguf D=HTP0 ./scripts/snapdragon/adb/run-cli.sh -no-cnv -p "what is the most popular cookie in the world?" +... +ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1 +ggml-hex: Hexagon Arch version v79 +ggml-hex: allocating new session: HTP0 +ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb4000072c7955e50 +... +load_tensors: offloading output layer to GPU +load_tensors: offloaded 17/17 layers to GPU +load_tensors: CPU model buffer size = 225.49 MiB +load_tensors: HTP0 model buffer size = 0.26 MiB +load_tensors: HTP0-REPACK model buffer size = 504.00 MiB +... +I hope this helps you understand the world's most popular cookies! [end of text] +... +llama_perf_sampler_print: sampling time = 30.08 ms / 487 runs ( 0.06 ms per token, 16191.77 tokens per second) +llama_perf_context_print: load time = 617.94 ms +llama_perf_context_print: prompt eval time = 80.76 ms / 11 tokens ( 7.34 ms per token, 136.21 tokens per second) +llama_perf_context_print: eval time = 9210.59 ms / 475 runs ( 19.39 ms per token, 51.57 tokens per second) +llama_perf_context_print: total time = 9454.92 ms / 486 tokens +llama_perf_context_print: graphs reused = 473 +llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | +llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - Host | 439 = 225 + 136 + 77 | +llama_memory_breakdown_print: | - HTP0-REPACK | 504 = 504 + 0 + 0 | +``` + +Summary request for OLMoE-1B-7B. This is a large model that requires two HTP sessions/devices + +``` +~/src/llama.cpp$ M=OLMoE-1B-7B-0125-Instruct-Q4_0.gguf NDEV=2 D=HTP0,HTP1 ./scripts/snapdragon/adb/run-cli.sh -f surfing.txt -no-cnv +... +ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1 +ggml-hex: Hexagon Arch version v81 +ggml-hex: allocating new session: HTP0 +ggml-hex: allocating new session: HTP1 +... +load_tensors: offloading output layer to GPU +load_tensors: offloaded 17/17 layers to GPU +load_tensors: CPU model buffer size = 143.86 MiB +load_tensors: HTP1 model buffer size = 0.23 MiB +load_tensors: HTP1-REPACK model buffer size = 1575.00 MiB +load_tensors: HTP0 model buffer size = 0.28 MiB +load_tensors: HTP0-REPACK model buffer size = 2025.00 MiB +... +llama_context: CPU output buffer size = 0.19 MiB +llama_kv_cache: HTP1 KV buffer size = 238.00 MiB +llama_kv_cache: HTP0 KV buffer size = 306.00 MiB +llama_kv_cache: size = 544.00 MiB ( 8192 cells, 16 layers, 1/1 seqs), K (q8_0): 272.00 MiB, V (q8_0): 272.00 MiB +llama_context: HTP0 compute buffer size = 15.00 MiB +llama_context: HTP1 compute buffer size = 15.00 MiB +llama_context: CPU compute buffer size = 24.56 MiB +... +llama_perf_context_print: prompt eval time = 1730.57 ms / 212 tokens ( 8.16 ms per token, 122.50 tokens per second) +llama_perf_context_print: eval time = 5624.75 ms / 257 runs ( 21.89 ms per token, 45.69 tokens per second) +llama_perf_context_print: total time = 7377.33 ms / 469 tokens +llama_perf_context_print: graphs reused = 255 +llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | +llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - Host | 742 = 144 + 544 + 54 | +llama_memory_breakdown_print: | - HTP1-REPACK | 1575 = 1575 + 0 + 0 | +llama_memory_breakdown_print: | - HTP0-REPACK | 2025 = 2025 + 0 + 0 | +``` + +Op test for MUL_MAT + +``` +~/src/llama.cpp$ HB=0 ./scripts/snapdragon/adb/run-tool.sh test-backend-ops -b HTP0 -o MUL_MAT +... +Backend 2/3: HTP0 +Device description: Hexagon +Device memory: 2048 MB (2048 MB free) +MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=1,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK +MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=2,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK +MUL_MAT(type_a=q4_0,type_b=f32,m=16,n=3,k=256,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=0,o=1): OK + +~/src/llama.cpp-hexagon$ M=Llama-3.2-1B-Instruct-Q4_0.gguf ./scripts/snapdragon/adb/run-bench.sh -p 128 -n 64 +... +ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev 1 +ggml-hex: Hexagon Arch version v79 +ggml-hex: allocating new session: HTP0 +ggml-hex: new session: HTP0 : session-id 0 domain-id 3 uri file:///libggml-htp-v79.so?htp_iface_skel_handle_invoke&_modver=1.0&_dom=cdsp&_session=0 handle 0xb400007d4b231090 +| model | size | params | backend | ngl | threads | n_batch | mmap | test | t/s | +| ---------------| ---------: | -----: | ---------- | --: | ------: | ------: | ---: | ----: | ------------: | +| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | pp128 | 169.42 ± 1.75 | +| llama 1B Q4_0 | 729.75 MiB | 1.24 B | HTP | 99 | 4 | 128 | 0 | tg64 | 51.54 ± 1.13 | + +build: 6a8cf8914 (6733) +``` + +## Environment variables + +- `GGML_HEXAGON_NDEV=1` + Controls the number of devices/sessions to allocate. The default is 1. + Most quantized models under 4B fit into a single session; an 8B model needs two, and a 20B model needs four. + +- `GGML_HEXAGON_NHVX=0` + Controls the number of HVX hardware threads to use. The default is all (actual number varies depending on the hardware version). + +- `GGML_HEXAGON_HOSTBUF=1` + Controls whether the Hexagon backend allocates host buffers. By default, all buffers except for REPACK are host buffers. + This option is required for testing Ops that require REPACK buffers (MUL_MAT and MUL_MAT_ID). + +- `GGML_HEXAGON_VERBOSE=1` + Enables verbose logging of Ops from the backend. Example output: + + ``` + ggml-hex: HTP0 graph-compute n_nodes 2 + ggml-hex: HTP0 matmul : blk.27.ffn_up.weight x ffn_norm-27 -> ffn_up-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x1 + ggml-hex: HTP0 matmul : blk.27.ffn_gate.weight x ffn_norm-27 -> ffn_gate-27 : 3072:8192 x 3072:1 -> 8192:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x3 + ggml-hex: HTP0 graph-compute n_nodes 1 + ggml-hex: HTP0 matmul : blk.27.ffn_down.weight x ffn_gate_par-27 -> ffn_out-27 : 8192:3072 x 8192:1 -> 3072:1 : q4_0 x f32 -> f32 : HTP0 x HTP0 -> HTP0 : flags 0x0 + ggml-hex: HTP0 get-tensor result_output : data 0x7592487000 offset 0 size 513024 + ``` + +- `GGML_HEXAGON_PROFILE=1` + Generates a host-side profile for the ggml-hexagon Ops. + +- `GGML_HEXAGON_OPMASK=0x0` + Allows enabling specific stages of the processing pipeline: + + - `0x1` Enable Op Queue (i.e., queuing Ops into NPU) + - `0x2` Enable Dynamic Quantizer (if needed for the Op) + - `0x4` Enable Op Compute (MUL_MAT, etc.) + + Examples: + + `GGML_HEXAGON_OPMASK=0x1 llama-cli ...` - Ops are enqueued but NPU-side processing is stubbed out + `GGML_HEXAGON_OPMASK=0x3 llama-cli ...` - NPU performs dynamic quantization and skips the rest + `GGML_HEXAGON_OPMASK=0x7 llama-cli ...` - Full queuing and processing of Ops (default) diff --git a/docs/backend/hexagon/developer.md b/docs/backend/hexagon/developer.md new file mode 100644 index 0000000000..200a7aabc0 --- /dev/null +++ b/docs/backend/hexagon/developer.md @@ -0,0 +1,109 @@ +# Hexagon backend developer details + +## Backend libraries + +The Hexagon backend consist of two parts: + + - `libggml-hexagon` + This is the regular CPU-side GGML backend library, either shared or statically linked + + - `libggml-htp-vNN` + This is the NPU-side (HTP stands for Hexagon Tensor Processor) shared library that contains the Op dispatcher and kernels. + The correct library is selected automatically at runtime based on the HW version. + +Here is an example of the build artifacts + +``` +~/src/llama.cpp$ ls -l pkg-adb/llama.cpp/lib/libggml* +pkg-adb/llama.cpp/lib/libggml-base.so +pkg-adb/llama.cpp/lib/libggml-cpu.so +pkg-adb/llama.cpp/lib/libggml-hexagon.so <<< CPU library +pkg-adb/llama.cpp/lib/libggml-htp-v73.so <<< HTP op/kernels for Hexagon v73 +pkg-adb/llama.cpp/lib/libggml-htp-v75.so +pkg-adb/llama.cpp/lib/libggml-htp-v79.so +pkg-adb/llama.cpp/lib/libggml-htp-v81.so +``` + +## Memory buffers + +Hexagon NPU backend takes advantage of the Snapdragon's unified memory model where all buffers are fully accessible by the CPU and GPU. +The NPU does have a dedicated tightly-coupled memory called VTCM but that memory is used only for intermediate data (e.g. dynamically +quantized tensors) or temporary data (chunks of the weight tensors fetched via DMA). + +Please note that currently the Hexagon backend does not implement SET/GET_ROWS Ops because there is no advantage in offloading those +to the NPU at this point. + +The backend does allocates non-host buffers for the tensors with datatypes that require repacking: Q4_0, Q8_0, MXFP4. +From the MMU perspective these buffers are still regular buffers (normal access by the CPU) they are marked as non-host simply to force +the repacking. + +## Large model handling + +Hexagon NPU session (aka Process Domain (PD) in the Hexagon docs) is limited to a memory mapping of around 3.5GB. +In llama.cpp/GGML the Hexagon session is mapped to a single GGML backend device (HTP0, HTP1, etc). + +In order to map models larger than 3.5GB we need to allocate multiple devices and split the model. +For this we're taking advantage of the llama.cpp/GGML multi-GPU layer-splitting support. +Each Hexagon device behaves like a GPU from the offload and model splitting perspective. + +Here is an example of running GPT-OSS-20B model on a newer Snapdragon device with 16GB of DDR. + +``` +M=gpt-oss-20b-Q4_0.gguf NDEV=4 D=HTP0,HTP1,HTP2,HTP3 P=surfing.txt scripts/snapdragon/adb/run-cli.sh -no-cnv -f surfing.txt -n 32 +... +LD_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib +ADSP_LIBRARY_PATH=/data/local/tmp/llama.cpp/lib +GGML_HEXAGON_NDEV=4 ./bin/llama-cli --no-mmap -m /data/local/tmp/llama.cpp/../gguf/gpt-oss-20b-Q4_0.gguf + -t 4 --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on -ngl 99 --device HTP0,HTP1,HTP2,HTP3 -no-cnv -f surfing.txt +... +llama_model_loader: - type f32: 289 tensors +llama_model_loader: - type q4_0: 96 tensors +llama_model_loader: - type q8_0: 2 tensors +llama_model_loader: - type mxfp4: 72 tensors +... +load_tensors: offloaded 25/25 layers to GPU +load_tensors: CPU model buffer size = 1182.09 MiB +load_tensors: HTP1 model buffer size = 6.64 MiB +load_tensors: HTP1-REPACK model buffer size = 2505.94 MiB +load_tensors: HTP3 model buffer size = 5.55 MiB +load_tensors: HTP3-REPACK model buffer size = 2088.28 MiB +load_tensors: HTP0 model buffer size = 7.75 MiB +load_tensors: HTP0-REPACK model buffer size = 2923.59 MiB +load_tensors: HTP2 model buffer size = 6.64 MiB +load_tensors: HTP2-REPACK model buffer size = 2505.94 MiB +... +llama_context: n_ctx_per_seq (8192) < n_ctx_train (131072) -- the full capacity of the model will not be utilized +llama_context: CPU output buffer size = 0.77 MiB +llama_kv_cache_iswa: creating non-SWA KV cache, size = 8192 cells +llama_kv_cache: HTP1 KV buffer size = 25.50 MiB +llama_kv_cache: HTP3 KV buffer size = 25.50 MiB +llama_kv_cache: HTP0 KV buffer size = 25.50 MiB +llama_kv_cache: HTP2 KV buffer size = 25.50 MiB +llama_kv_cache: size = 102.00 MiB ( 8192 cells, 12 layers, 1/1 seqs), K (q8_0): 51.00 MiB, V (q8_0): 51.00 MiB +llama_kv_cache_iswa: creating SWA KV cache, size = 256 cells +llama_kv_cache: HTP1 KV buffer size = 0.80 MiB +llama_kv_cache: HTP3 KV buffer size = 0.53 MiB +llama_kv_cache: HTP0 KV buffer size = 1.06 MiB +llama_kv_cache: HTP2 KV buffer size = 0.80 MiB +llama_kv_cache: size = 3.19 MiB ( 256 cells, 12 layers, 1/1 seqs), K (q8_0): 1.59 MiB, V (q8_0): 1.59 MiB +llama_context: HTP0 compute buffer size = 16.06 MiB +llama_context: HTP1 compute buffer size = 16.06 MiB +llama_context: HTP2 compute buffer size = 16.06 MiB +llama_context: HTP3 compute buffer size = 16.06 MiB +llama_context: CPU compute buffer size = 98.19 MiB +... +llama_perf_context_print: prompt eval time = 3843.67 ms / 197 tokens ( 19.51 ms per token, 51.25 tokens per second) +llama_perf_context_print: eval time = 1686.13 ms / 31 runs ( 54.39 ms per token, 18.39 tokens per second) +llama_perf_context_print: total time = 6266.30 ms / 228 tokens +llama_perf_context_print: graphs reused = 30 +llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | +llama_memory_breakdown_print: | - HTP0 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - HTP1 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - HTP2 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - HTP3 (Hexagon) | 2048 = 2048 + ( 0 = 0 + 0 + 0) + 0 | +llama_memory_breakdown_print: | - Host | 1476 = 1208 + 105 + 162 | +llama_memory_breakdown_print: | - HTP1-REPACK | 2505 = 2505 + 0 + 0 | +llama_memory_breakdown_print: | - HTP3-REPACK | 2088 = 2088 + 0 + 0 | +llama_memory_breakdown_print: | - HTP0-REPACK | 2923 = 2923 + 0 + 0 | +llama_memory_breakdown_print: | - HTP2-REPACK | 2505 = 2505 + 0 + 0 | +``` diff --git a/docs/build.md b/docs/build.md index dcbcce7549..b410c710e3 100644 --- a/docs/build.md +++ b/docs/build.md @@ -261,10 +261,12 @@ You can download it from your Linux distro's package manager or from here: [ROCm - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU): ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build --config Release -- -j 16 ``` + Note: `GPU_TARGETS` is optional, omitting it will build the code for all GPUs in the current system. + To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. The rocWMMA library is included by default when installing the ROCm SDK using the `rocm` meta package provided by AMD. Alternatively, if you are not using the meta package, you can install the library using the `rocwmma-dev` or `rocwmma-devel` package, depending on your system's package manager. @@ -282,17 +284,17 @@ You can download it from your Linux distro's package manager or from here: [ROCm ```bash HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -p)" \ HIP_DEVICE_LIB_PATH= \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ + cmake -S . -B build -DGGML_HIP=ON -DGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build -- -j 16 ``` - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU): ```bash set PATH=%HIP_PATH%\bin;%PATH% - cmake -S . -B build -G Ninja -DAMDGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release + cmake -S . -B build -G Ninja -DGPU_TARGETS=gfx1100 -DGGML_HIP=ON -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release cmake --build build ``` - Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) + If necessary, adapt `GPU_TARGETS` to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors) Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`. diff --git a/docs/ops.md b/docs/ops.md index dfd1cfab6a..3738a48072 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -79,7 +79,7 @@ Legend: | REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | | REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | | RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | -| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | +| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | | RMS_NORM_MUL_ADD | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | ROLL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | | ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index fe6876357f..101e80f64c 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -5637,25 +5637,25 @@ "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000000,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000000,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000001,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000001","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000100,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.000100,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.000100","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.100000,inplace=0","support","1","yes","SYCL" "SYCL0","NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=1,eps=0.100000,inplace=0","support","1","yes","SYCL" -"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","0","no","SYCL" +"SYCL0","RMS_NORM_BACK","type=f32,ne=[64,5,4,3],eps=0.100000","support","1","yes","SYCL" "SYCL0","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","SYCL" "SYCL0","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","SYCL" "SYCL0","RMS_NORM_MUL_ADD","type=f32,ne=[64,5,4,3],eps=0.000000,broadcast=0,multi_add=0","support","1","yes","SYCL" diff --git a/examples/embedding/README.md b/examples/embedding/README.md index 3dd279d9fc..1684f36480 100644 --- a/examples/embedding/README.md +++ b/examples/embedding/README.md @@ -38,6 +38,7 @@ The above command will output space-separated float values. | | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ | 'json' | openai style | | 'json+' | add cosine similarity matrix | +| 'raw' | plain text output | ### --embd-separator $"string"$ | $"string"$ | | diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 388908bc4d..9e3ab5905b 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -70,6 +70,29 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } } +// plain, pipe-friendly output: one embedding per line +static void print_raw_embeddings(const float * emb, + int n_embd_count, + int n_embd, + const llama_model * model, + enum llama_pooling_type pooling_type, + int embd_normalize) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); + const int cols = is_rank ? std::min(n_embd, (int) n_cls_out) : n_embd; + + for (int j = 0; j < n_embd_count; ++j) { + for (int i = 0; i < cols; ++i) { + if (embd_normalize == 0) { + LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } else { + LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } + } + LOG("\n"); + } +} + int main(int argc, char ** argv) { common_params params; @@ -372,6 +395,8 @@ int main(int argc, char ** argv) { } if (notArray) LOG("\n}\n"); + } else if (params.embd_out == "raw") { + print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); } LOG("\n"); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 2d57549046..26989157fe 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -371,8 +371,17 @@ class SchemaConverter: raise ValueError(f'Unsupported ref {ref}') for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] + assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}' + if isinstance(target, list): + try: + sel_index = int(sel) + except ValueError: + raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}') + assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel_index] + else: + assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] self._refs[ref] = target else: @@ -547,7 +556,8 @@ class SchemaConverter: def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] + ref_fragment = ref.split('#')[-1] + ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment) if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 9444c713d0..7fb55e9af1 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -138,7 +138,7 @@ if model_path is None: "Model path must be specified either via --model-path argument or MODEL_PATH environment variable" ) -config = AutoConfig.from_pretrained(model_path) +config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) print("Model type: ", config.model_type) print("Vocab size: ", config.vocab_size) @@ -148,8 +148,8 @@ print("BOS token id: ", config.bos_token_id) print("EOS token id: ", config.eos_token_id) print("Loading model and tokenizer using AutoTokenizer:", model_path) -tokenizer = AutoTokenizer.from_pretrained(model_path) -config = AutoConfig.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if unreleased_model_name: model_name_lower = unreleased_model_name.lower() @@ -171,7 +171,7 @@ if unreleased_model_name: exit(1) else: model = AutoModelForCausalLM.from_pretrained( - model_path, device_map="auto", offload_folder="offload" + model_path, device_map="auto", offload_folder="offload", trust_remote_code=True ) for name, module in model.named_modules(): diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 73032be68e..181f179ed1 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -251,6 +251,8 @@ option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adr set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING "gmml: OpenCL API version to target") +option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) + # toolchain for vulkan-shaders-gen set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") diff --git a/ggml/include/ggml-hexagon.h b/ggml/include/ggml-hexagon.h new file mode 100644 index 0000000000..6e07900410 --- /dev/null +++ b/ggml/include/ggml-hexagon.h @@ -0,0 +1,19 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_hexagon_init(void); + +GGML_BACKEND_API bool ggml_backend_is_hexagon(ggml_backend_t backend); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_hexagon_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3356ef550d..ba281b8e6d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -402,6 +402,7 @@ ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) ggml_add_backend(zDNN) ggml_add_backend(OpenCL) +ggml_add_backend(Hexagon) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $ $) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index c830c09655..91aff205f1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -226,16 +226,23 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al } if (best_fit_block == -1) { - // no suitable block found, try the last block (this will grow a chunks size) + // no suitable block found, try the last block (this may grow a chunks size) + int64_t best_reuse = INT64_MIN; for (int c = 0; c < alloc->n_chunks; ++c) { struct tallocr_chunk * chunk = alloc->chunks[c]; if (chunk->n_free_blocks > 0) { struct free_block * block = &chunk->free_blocks[chunk->n_free_blocks - 1]; max_avail = MAX(max_avail, block->size); - if (block->size >= size) { + int64_t reuse_factor = chunk->max_size - block->offset - size; + // reuse_factor < 0 : amount of extra memory that needs to be allocated + // reuse_factor = 0 : allocated free space exactly matches tensor size + // reuse_factor > 0 : superfluous memory that will remain unused + bool better_reuse = best_reuse < 0 && reuse_factor > best_reuse; + bool better_fit = reuse_factor >= 0 && reuse_factor < best_reuse; + if (block->size >= size && (better_reuse || better_fit)) { best_fit_chunk = c; best_fit_block = chunk->n_free_blocks - 1; - break; + best_reuse = reuse_factor; } } } @@ -268,7 +275,7 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al #ifdef GGML_ALLOCATOR_DEBUG add_allocated_tensor(alloc, addr, tensor); size_t cur_max = addr.offset + size; - if (cur_max > alloc->max_size[addr.chunk]) { + if (cur_max > chunk->max_size) { // sort allocated_tensors by chunk/offset for (int i = 0; i < 1024; i++) { for (int j = i + 1; j < 1024; j++) { diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 136afec748..e96b5c403d 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -57,6 +57,10 @@ #include "ggml-opencl.h" #endif +#ifdef GGML_USE_HEXAGON +#include "ggml-hexagon.h" +#endif + #ifdef GGML_USE_BLAS #include "ggml-blas.h" #endif @@ -199,6 +203,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif +#ifdef GGML_USE_HEXAGON + register_backend(ggml_backend_hexagon_reg()); +#endif #ifdef GGML_USE_CANN register_backend(ggml_backend_cann_reg()); #endif @@ -598,6 +605,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("sycl", silent, dir_path); ggml_backend_load_best("vulkan", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path); + ggml_backend_load_best("hexagon", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index f030ea0136..5df6dc96a3 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, ACL_MEM_MALLOC_HUGE_FIRST)); acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float), - theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); + theta_scale_ne, theta_scale_nb, 1); float start = 0; float step = 1; @@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx, yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); void * yarn_ramp_buffer = yarn_ramp_allocator.get(); acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, - theta_scale_nb, GGML_MAX_DIMS); + theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); aclScalar * low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 8bd5449f1f..51345742ee 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -67,19 +67,30 @@ GGML_ABORT("CANN error"); } +// Thread-local variable to record the current device of this thread. +thread_local int g_current_cann_device = -1; + /** - * @brief Sets the device to be used by CANN. + * @brief Set the CANN device to be used. * - * @param device The device ID to set. + * @param device The target device ID to set. */ void ggml_cann_set_device(const int32_t device) { - int current_device = -1; - aclrtGetDevice(¤t_device); + // int current_device = -1; + // Note: In some CANN versions, if no device has been set yet, + // aclrtGetDevice(¤t_device) may return 0 by default. + // aclrtGetDevice(¤t_device); - if (device == current_device) { + // If the current device is already the target one, no need to switch. + if (device == g_current_cann_device) { return; } + + // Switch to the new device. ACL_CHECK(aclrtSetDevice(device)); + + // Update the global device record. + g_current_cann_device = device; } /** diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b52f0f8472..3156bd6010 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7519,8 +7519,8 @@ static void ggml_compute_forward_upscale_f32( float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { pixel_offset = 0.0f; - sf0 = (float)(ne0 - 1) / (src0->ne[0] - 1); - sf1 = (float)(ne1 - 1) / (src0->ne[1] - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; } for (int64_t i3 = 0; i3 < ne3; i3++) { diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index fbf8873c31..65c7dfb6b9 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -77,85 +77,16 @@ inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i])); } } -inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { - int i = 0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vv = GGML_F32_VEC_SET1(v); - - for (; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; ++j) { - GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - GGML_F32_VEC az = GGML_F32_VEC_ADD(ax, vv); - GGML_F32_VEC_STORE(z + i + j*GGML_F32_EPR, az); - } - } -#endif - for (; i < n; ++i) { - z[i] = x[i] + v; - } -} -inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { - int i = 0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - for (; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; ++j) { - GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay = GGML_F32_VEC_ADD(ay, ax); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay); - } - } -#endif - for (; i < n; ++i) { - y[i] += x[i]; - } -} -inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { - int i = 0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vv = GGML_F32_VEC_SET1(v); - - for (; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; ++j) { - GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay = GGML_F32_VEC_ADD(ay, vv); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay); - } - } -#endif - for (; i < n; ++i) { - y[i] += v; - } -} +inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { for (int i = 0; i < n; ++i) { z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) - GGML_CPU_FP16_TO_FP32(y[i])); } } -inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { - int i = 0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - - for (; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; ++j) { - GGML_F32_VEC_STORE(x + i + j*GGML_F32_EPR, vx); - } - } -#endif - for (; i < n; ++i) { - x[i] = v; - } -} +inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { @@ -164,24 +95,7 @@ inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp } } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { - int i = 0; -#if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - for (; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; ++j) { - GGML_F32_VEC ax = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - GGML_F32_VEC ay = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - GGML_F32_VEC az = GGML_F32_VEC_MUL(ax, ay); - GGML_F32_VEC_STORE(z + i + j*GGML_F32_EPR, az); - } - } -#endif - for (; i < n; ++i) { - z[i] = x[i]*y[i]; - } -} +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { for (int i = 0; i < n; ++i) { z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i])); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 607ded8558..6e7b90d427 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -1,5 +1,81 @@ #include "argsort.cuh" +#ifdef GGML_CUDA_USE_CUB +# include +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +static __global__ void init_indices(int * indices, const int ncols, const int nrows) { + const int col = blockIdx.x * blockDim.x + threadIdx.x; + const int row = blockIdx.y; + + if (col < ncols && row < nrows) { + indices[row * ncols + col] = col; + } +} + +static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx <= nrows) { + offsets[idx] = idx * ncols; + } +} + +#ifdef GGML_CUDA_USE_CUB +static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { + ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + + int * temp_indices = temp_indices_alloc.get(); + float * temp_keys = temp_keys_alloc.get(); + int * d_offsets = offsets_alloc.get(); + + static const int block_size = 256; + const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); + init_indices<<>>(temp_indices, ncols, nrows); + + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<>>(d_offsets, ncols, nrows); + + cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream); + + size_t temp_storage_bytes = 0; + + if (order == GGML_SORT_ORDER_ASC) { + DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits + stream); + } else { + DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, + sizeof(float) * 8, stream); + } + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + if (order == GGML_SORT_ORDER_ASC) { + DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, + stream); + } else { + DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, + 0, sizeof(float) * 8, stream); + } +} +#endif // GGML_CUDA_USE_CUB + +// Bitonic sort implementation template static inline __device__ void ggml_cuda_swap(T & a, T & b) { T tmp = a; @@ -65,7 +141,12 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { +static void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); @@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32 + <<>>(x, dst, ncols, ncols_pad); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad); + k_argsort_f32_i32 + <<>>(x, dst, ncols, ncols_pad); } else { GGML_ABORT("fatal error"); } @@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; - argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream); +#ifdef GGML_CUDA_USE_CUB + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + if (shared_mem > max_shared_mem || ncols > 1024) { + ggml_cuda_pool & pool = ctx.pool(); + argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); + } +#else + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); +#endif } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 6024010274..0e6d777b1e 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); - if (block_nums.z > 65535) { + if (block_nums.z > 65535 || block_nums.y > 65535) { int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 41ff89c4d6..6a472be7fb 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { // and a shift: // // n/d = (mulhi(n, mp) + n) >> L; -static const uint3 init_fastdiv_values(uint32_t d) { - GGML_ASSERT(d != 0); +static const uint3 init_fastdiv_values(uint64_t d_64) { + GGML_ASSERT(d_64 != 0); + GGML_ASSERT(d_64 <= std::numeric_limits::max()); + + uint32_t d = (uint32_t)d_64; // compute L = ceil(log2(d)); uint32_t L = 0; @@ -1005,3 +1008,16 @@ struct ggml_backend_cuda_context { return pool(device); } }; + +struct ggml_cuda_mm_fusion_args_host { + const ggml_tensor * x_bias = nullptr; + const ggml_tensor * gate = nullptr; + const ggml_tensor * gate_bias = nullptr; + ggml_glu_op glu_op; +}; +struct ggml_cuda_mm_fusion_args_device { + const void * x_bias = nullptr; + const void * gate = nullptr; + const void * gate_bias = nullptr; + ggml_glu_op glu_op; +}; diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index ef9e129950..8a5e08ef66 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 12d5bf7763..c5821acbde 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, cpy_blck(cx + x_offset, cdst + dst_offset); } +template +static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) { + const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const src_t * x = (const src_t *) cx; + dst_t * dst = (dst_t *) cdst; + + dst[i] = ggml_cuda_cast(x[i]); +} + +template +static void ggml_cpy_flt_contiguous_cuda( + const char * cx, char * cdst, const int64_t ne, +cudaStream_t stream) { + + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + cpy_flt_contiguous<<>> + (cx, cdst, ne); +} + template static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, @@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + + if (src0->type == src1->type && contiguous_srcs) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); #if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { @@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { @@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + if (contiguous_srcs) { + ggml_cpy_flt_contiguous_cuda (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6e7c5aedbc..fcff5d7cdc 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -50,6 +50,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" @@ -1957,8 +1958,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct size_t src1_stride_size = sizeof(cuda_t); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( + const int threads_x = 16; + const int threads_y = 16; + dim3 block_dims(threads_x, threads_y); + + dim3 grid_dims( + (ne13 + threads_x - 1) / threads_x, + (ne12 + threads_y - 1) / threads_y + ); + k_compute_batched_ptrs<<>>( src0_ptr, src1_ptr, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, @@ -2007,6 +2015,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } +static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, + const ggml_tensor * ffn_gate, + const ggml_tensor * glu, + const ggml_tensor * ffn_up_bias = nullptr, + const ggml_tensor * ffn_gate_bias = nullptr) { + const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr; + + if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) { + return false; + } + + const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU; + const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU; + + GGML_ASSERT(ffn_up && ffn_gate && glu); + + if (!is_mul_mat && !is_mul_mat_id) { + return false; + } + + const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (has_bias) { + if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) { + return false; + } + + if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) { + return false; + } + + if (expected_bias_op == GGML_OP_ADD) { + const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up; + const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate; + if (!up_has_mul || !gate_has_mul) { + return false; + } + } else { // GGML_OP_ADD_ID + if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) { + return false; + } + if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) { + return false; + } + } + } else { + if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) { + return false; + } + } + + if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) || + !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) { + return false; + } + + if (ffn_up->src[1] != ffn_gate->src[1]) { + return false; + } + + if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) { + return false; + } + + static constexpr std::array valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI }; + + if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) { + return false; + } + + if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) { + return false; + } + + const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) || + ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + return true; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID; + + bool use_mul_mat_vec_f = + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && + src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]); + + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + + return use_mul_mat_vec_f; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && + ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && + src0->view_src; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + // fusion is not universally faster on Pascal + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (cc <= GGML_CUDA_CC_PASCAL) { + return false; + } + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + return use_mul_mat_vec_q; +} + static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); @@ -2268,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SET_ROWS: ggml_cuda_op_set_rows(ctx, dst); break; + case GGML_OP_SET: + ggml_cuda_op_set(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2745,7 +2897,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } } - if (node->op == GGML_OP_SCALE && + if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } @@ -2826,9 +2978,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); if (ops.size() == topk_moe_ops_with_norm.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx+8]; + ggml_tensor * weights = cgraph->nodes[node_idx + 9]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { return true; @@ -2836,16 +2988,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } if (ops.size() == topk_moe_ops.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx+4]; + ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { return true; } } if (ops.size() == topk_moe_ops_delayed_softmax.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) { + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; @@ -2854,6 +3006,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } + std::initializer_list mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; + std::initializer_list mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; + + std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; + std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; + + if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || + ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { + + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; + const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { + return true; + } + } + + if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || + ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { + + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { + return true; + } + } + if (!ggml_can_fuse(cgraph, node_idx, ops)) { return false; } @@ -2934,9 +3118,20 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { + [[maybe_unused]] int prev_i = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + +#ifdef GGML_CUDA_DEBUG + const int nodes_fused = i - prev_i - 1; + prev_i = i; + if (nodes_fused > 0) { + GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); + } +#endif + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } @@ -2945,17 +3140,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (!disable_fusion) { if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i+8]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_tensor * weights = cgraph->nodes[i + 9]; + ggml_tensor * selected_experts = cgraph->nodes[i + 3]; + ggml_tensor * clamp = cgraph->nodes[i + 7]; ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false); - i += 8; + /*delayed softmax*/ false, clamp); + i += 9; continue; } if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i+4]; - ggml_tensor * selected_experts = cgraph->nodes[i+3]; + ggml_tensor * weights = cgraph->nodes[i + 4]; + ggml_tensor * selected_experts = cgraph->nodes[i + 3]; ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, /*delayed softmax*/ false); i += 4; @@ -3004,6 +3200,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) + || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) continue; + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); @@ -3483,6 +3857,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; + case GGML_OP_SET: + { + const ggml_type t = op->type; + return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) && + t == op->src[0]->type && + t == op->src[1]->type; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; @@ -3642,8 +4023,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ARGSORT: - // TODO: Support arbitrary column width +#ifndef GGML_CUDA_USE_CUB return op->src[0]->ne[0] <= 1024; +#else + return true; +#endif case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 57ab839393..4e31783436 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -1,11 +1,12 @@ #include "ggml.h" #include "common.cuh" -#include "convert.cuh" +#include "unary.cuh" #include "mmvf.cuh" +#include "convert.cuh" -template +template static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { @@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f( y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU; + const T * gate_x = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr; + glu_op = fusion.glu_op; + + if (use_gate) { + gate_x = static_cast(fusion.gate); + } + if (use_bias) { + x_bias = static_cast(fusion.x_bias); + } + if (use_gate_bias) { + gate_bias = static_cast(fusion.gate_bias); + use_gate_bias = use_gate; + } else { + use_gate_bias = false; + } + } + + if (use_gate) { + gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + } + if constexpr (has_fusion) { + const int channel_bias = ids ? channel_x : channel_dst; + if (use_bias) { + x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + if (use_gate_bias) { + gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + } + const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; + float * buf_iw_gate = nullptr; + if constexpr (has_fusion) { + buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); + } if (block_size > warp_size) { if (tid < warp_size) { buf_iw[tid] = 0.0f; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid] = 0.0f; + } + } } __syncthreads(); } float sumf[ncols_dst] = {0.0f}; + float sumf_gate[ncols_dst]; + if constexpr (has_fusion) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = 0.0f; + } + } if constexpr (std::is_same_v) { const float2 * x2 = (const float2 *) x; + const float2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const float2 *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = x2[col2]; + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } } else if constexpr (std::is_same_v) { const half2 * x2 = (const half2 *) x; + const half2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const half2 *) gate_x; + } + } if (std::is_same_v) { for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); - + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = __half22float2(gate_x2[col2]); + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } } else { #ifdef FP16_AVAILABLE half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; + half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}}; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const half2 tmpx = x2[col2]; - + half2 tmpx_gate = make_half2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y); + } + } } } @@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f( for (int j = 0; j < ncols_dst; ++j) { sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); } + + if constexpr (has_fusion) { + if (use_gate) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]); + } + } + } #else NO_DEVICE_CODE; #endif // FP16_AVAILABLE @@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f( //TODO: add support for ggml_cuda_mad for hip_bfloat162 #if defined(GGML_USE_HIP) const int * x2 = (const int *) x; + const int * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const int *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; + int tmpx_gate = 0; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; @@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f( const float tmpx1 = ggml_cuda_cast(reinterpret_cast(&tmpx)[1]); ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + const float tmpx0_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[0]); + const float tmpx1_gate = ggml_cuda_cast(reinterpret_cast(&tmpx_gate)[1]); + ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y); + } + } } } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + const nv_bfloat162 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const nv_bfloat162 *) gate_x; + } + } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; + nv_bfloat162 tmpx_gate; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } #pragma unroll for (int j = 0; j < ncols_dst; ++j) { const float2 tmpy = y2[j*stride_col_y2 + col2]; ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } } } #endif @@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f( for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = warp_reduce_sum(sumf_gate[j]); + } + } + if (block_size > warp_size) { buf_iw[tid/warp_size] = sumf[j]; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid/warp_size] = sumf_gate[j]; + } + } __syncthreads(); if (tid < warp_size) { sumf[j] = buf_iw[tid]; sumf[j] = warp_reduce_sum(sumf[j]); + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = buf_iw_gate[tid]; + sumf_gate[j] = warp_reduce_sum(sumf_gate[j]); + } + } } + if (j < ncols_dst) { __syncthreads(); } @@ -139,12 +313,74 @@ static __global__ void mul_mat_vec_f( return; } - dst[tid*stride_col_dst + row] = sumf[tid]; + float value = sumf[tid]; + + if constexpr (has_fusion) { + if (use_bias) { + value += x_bias[tid*stride_col_dst + row]; + } + + if (use_gate) { + float gate_value = sumf_gate[tid]; + if (use_gate_bias) { + gate_value += gate_bias[tid*stride_col_dst + row]; + } + switch (glu_op) { + case GGML_GLU_OP_SWIGLU: + value *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + value *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + value = ggml_cuda_op_swiglu_oai_single(gate_value, value); + break; + } + default: + break; + } + } + } + + dst[tid*stride_col_dst + row] = value; + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate); + } +} + +template +static void mul_mat_vec_f_switch_fusion( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } template -static void launch_mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, +void launch_mul_mat_vec_f_cuda( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, @@ -176,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda( } } - const int nbytes_shared = warp_size*sizeof(float); + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + + const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 64: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 96: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 128: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 160: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 192: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 224: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; case 256: { - mul_mat_vec_f<<>> - (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); } break; default: { GGML_ABORT("fatal error"); @@ -236,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda( template static void mul_mat_vec_f_cuda_switch_ncols_dst( - const T * x, const float * y, const int32_t * ids, float * dst, + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, @@ -246,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 2: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 3: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 4: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 5: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 6: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 7: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 8: launch_mul_mat_vec_f_cuda - (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; @@ -300,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( template static void mul_mat_vec_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { + if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst - (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -348,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + const int64_t s01 = src0->nb[1] / ts_src0; const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s1 = dst->nb[1] / ts_dst; @@ -370,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -409,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f( const int cc = ggml_cuda_info().devices[id].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - // ggml_cuda_op provides single, contiguous matrices const int64_t stride_row = ne00; const int64_t stride_col_y = ne10; @@ -426,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f( const int64_t stride_sample_y = 0; const int64_t stride_sample_dst = 0; + ggml_cuda_mm_fusion_args_device empty{}; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index 1da460992e..a205aa8e4c 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,6 +1,7 @@ #include "common.cuh" -void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion = nullptr); void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 3bf0c9ed25..be04a85cc5 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,5 +1,6 @@ #include "mmvq.cuh" #include "quantize.cuh" +#include "unary.cuh" #include "vecdotq.cuh" #include @@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { case 1: @@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template // tell the compiler to use as many registers as it wants, see nwarps definition below +template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, @@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q( const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + const void * vgate = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + ggml_glu_op active_glu; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + vgate = fusion.gate; + x_bias = (const float *) fusion.x_bias; + gate_bias = (const float *) fusion.gate_bias; + active_glu = fusion.glu_op; + } + + const uint32_t channel_bias = ids ? channel_x : channel_dst; + + if constexpr (has_fusion) { + if (use_bias) { + x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + if (use_gate_bias) { + gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + } + } + // partial sum for each thread float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; @@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q( for (int i = 0; i < rows_per_cuda_block; ++i) { tmp[j][i] += vec_dot_q_cuda( vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += vec_dot_q_cuda( + vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } } } } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + if constexpr (!has_fusion) { + (void) tmp_shared_gate; + } else if (!use_gate) { + (void) tmp_shared_gate; + } + if (threadIdx.y > 0) { #pragma unroll for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i]; + } + } } } } @@ -216,14 +265,55 @@ static __global__ void mul_mat_vec_q( #pragma unroll for (int l = 0; l < nwarps-1; ++l) { tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x]; + } + } } tmp[j][i] = warp_reduce_sum(tmp[j][i]); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] = warp_reduce_sum(tmp_gate[j][i]); + } + } } if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { - dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; + float result = tmp[j][threadIdx.x]; + if constexpr (has_fusion) { + if (use_bias) { + result += x_bias[j*stride_col_dst + threadIdx.x]; + } + if (use_gate) { + float gate_value = tmp_gate[j][threadIdx.x]; + if (use_gate_bias) { + gate_value += gate_bias[j*stride_col_dst + threadIdx.x]; + } + switch (active_glu) { + case GGML_GLU_OP_SWIGLU: + result *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + result *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + result = ggml_cuda_op_swiglu_oai_single(gate_value, result); + break; + } + default: + result = result * gate_value; + break; + } + } + } + dst[j*stride_col_dst + threadIdx.x] = result; } } + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + } } static std::pair calc_launch_params( @@ -235,9 +325,37 @@ static std::pair calc_launch_params( return {block_nums, block_dims}; } +template +static void mul_mat_vec_q_switch_fusion( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (c_ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_q<<>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_q<<>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); +} + template static void mul_mat_vec_q_switch_ncols_dst( - const void * vx, const void * vy, const int32_t * ids, float * dst, + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst, @@ -256,80 +374,83 @@ static void mul_mat_vec_q_switch_ncols_dst( const int warp_size = ggml_cuda_info().devices[device].warp_size; const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 2: { constexpr int c_ncols_dst = 2; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 3: { constexpr int c_ncols_dst = 3; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 4: { constexpr int c_ncols_dst = 4; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 5: { constexpr int c_ncols_dst = 5; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 6: { constexpr int c_ncols_dst = 6; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 7: { constexpr int c_ncols_dst = 7; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; case 8: { constexpr int c_ncols_dst = 8; std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q<<>> - (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, stream); } break; default: GGML_ABORT("fatal error"); break; } -} + GGML_UNUSED(has_fusion); +} static void mul_mat_vec_q_switch_type( - const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst, + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst, @@ -339,143 +460,123 @@ static void mul_mat_vec_q_switch_type( switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst - (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, - stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; default: GGML_ABORT("fatal error"); @@ -484,7 +585,8 @@ static void mul_mat_vec_q_switch_type( } void ggml_cuda_mul_mat_vec_q( - ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. @@ -508,6 +610,31 @@ void ggml_cuda_mul_mat_vec_q( const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; float * dst_d = (float *) dst->data; + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + // If src0 is a temporary compute buffer, clear any potential padding. if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { const size_t size_data = ggml_nbytes(src0); @@ -549,10 +676,10 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_y = ids ? s11 : s12; mul_mat_vec_q_switch_type( - src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00, + src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -578,8 +705,9 @@ void ggml_cuda_op_mul_mat_vec_q( const int stride_row_x = ne00 / ggml_blck_size(src0->type); const int stride_col_y = src1_padded_row_size / QK8_1; + ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( - src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 39dc7d33eb..4bb10cfaec 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -3,7 +3,7 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 1525a15952..631de7e8fa 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -4,30 +4,53 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); // Generic quantized set_rows kernel template -template -static __global__ void k_set_rows_quant( - const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, - const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t s10, const int64_t s11, const int64_t s12, - const int64_t s1, const int64_t s2, const int64_t s3) { - +template +static __global__ void k_set_rows_quant(const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_type * __restrict__ dst, + const int64_t ne_total, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const uint3 ne00, + const uint3 ne01, + const uint3 ne02, + const uint3 ne11_fd, + const uint3 ne12_fd) { const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; - const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; if (i >= ne_total) { return; } const int64_t i_base = i * qk; - const int64_t i03 = i_base / (ne00 * ne01 * ne02); - const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); - const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; - const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + uint32_t tmp = (uint32_t) i_base; + uint2 div_mod; - const int64_t i12 = i03 % ne12; - const int64_t i11 = i02 % ne11; + div_mod = fast_div_modulo(tmp, ne00); + const int64_t i00 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne01); + const int64_t i01 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne02); + const int64_t i02 = div_mod.y; + const int64_t i03 = div_mod.x; + + const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd); + const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); @@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant( quantize_func(src_block, dst_block); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); } @@ -71,40 +96,65 @@ static void set_rows_cuda_quant( const int64_t s2 = nb2; const int64_t s3 = nb3; - if (ne_total > 0) { + if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) { + const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00); + const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01); + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); + const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); + k_set_rows_quant<<>>( - src0_d, src1_d, dst_d, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - s01, s02, s03, - s10, s11, s12, - s1, s2, s3); + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, + ne01_fd, ne02_fd, ne11_fd, ne12_fd); } } -template -static __global__ void k_set_rows( - const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst, - const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, - const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, - const int64_t s01, const int64_t s02, const int64_t s03, - const int64_t s10, const int64_t s11, const int64_t s12, - const int64_t s1, const int64_t s2, const int64_t s3) { - +template +static __global__ void k_set_rows(const src_t * __restrict__ src0, + const idx_t * __restrict__ src1, + dst_t * __restrict__ dst, + const int64_t ne_total, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const uint3 ne00, + const uint3 ne01, + const uint3 ne02, + const uint3 ne11_fd, + const uint3 ne12_fd) { const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; - const int64_t ne_total = ne00 * ne01 * ne02 * ne03; if (i >= ne_total) { return; } - const int64_t i03 = i / (ne00 * ne01 * ne02); - const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); - const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; - const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + uint32_t tmp = (uint32_t) i; + uint2 div_mod; - const int64_t i12 = i03 % ne12; - const int64_t i11 = i02 % ne11; + div_mod = fast_div_modulo(tmp, ne00); + const int64_t i00 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne01); + const int64_t i01 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne02); + const int64_t i02 = div_mod.y; + const int64_t i03 = div_mod.x; + + const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd); + const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); @@ -115,6 +165,8 @@ static __global__ void k_set_rows( dst_row_ptr[i00] = ggml_cuda_cast(src0_row[i00]); GGML_UNUSED(ne10); + GGML_UNUSED(ne11); + GGML_UNUSED(ne12); GGML_UNUSED(ne13); } @@ -144,14 +196,16 @@ static void set_rows_cuda( const int64_t s2 = nb2/sizeof(dst_t); const int64_t s3 = nb3/sizeof(dst_t); - if (ne_total > 0) { - k_set_rows<<>>( - src0_d, src1_d, dst_d, - ne00, ne01, ne02, ne03, - ne10, ne11, ne12, ne13, - s01, s02, s03, - s10, s11, s12, - s1, s2, s3); + if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) { + const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00); + const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01); + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); + const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); + + k_set_rows<<>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu new file mode 100644 index 0000000000..04bfe07ba0 --- /dev/null +++ b/ggml/src/ggml-cuda/set.cu @@ -0,0 +1,39 @@ +#include "set.cuh" +#include "cpy.cuh" + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)); + GGML_ASSERT(src1->type == src0->type); + GGML_ASSERT(dst ->type == src0->type); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t nb1 = ((int32_t *) dst->op_params)[0]; + const size_t nb2 = ((int32_t *) dst->op_params)[1]; + const size_t nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace= (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + ggml_cuda_cpy(ctx, src0, dst); + } + + ggml_tensor dst_view = *dst; + dst_view.data = (void *)((char *)dst->data + offset); + dst_view.ne[0] = src1->ne[0]; + dst_view.ne[1] = src1->ne[1]; + dst_view.ne[2] = src1->ne[2]; + dst_view.ne[3] = src1->ne[3]; + + dst_view.nb[0] = ggml_element_size(dst); + dst_view.nb[1] = nb1; + dst_view.nb[2] = nb2; + dst_view.nb[3] = nb3; + + ggml_cuda_cpy(ctx, src1, &dst_view); +} diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh new file mode 100644 index 0000000000..dd09529f3e --- /dev/null +++ b/ggml/src/ggml-cuda/set.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_BLOCK_SIZE 256 + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index e28c810ac5..572379fcbf 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -2,6 +2,7 @@ #include "ggml.h" #include "topk-moe.cuh" +#include #include // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. @@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float * weights, int32_t * ids, const int n_rows, - const int n_expert_used) { + const int n_expert_used, + const float clamp_val) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * if constexpr (with_norm) { wt_sum = warp_reduce_sum(wt_sum); + wt_sum = max(wt_sum, clamp_val); const float inv_sum = 1.0f / wt_sum; for (int i = 0; i < experts_per_thread; i++) { @@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * weights[idx] = output_weights[i]; } } + + if (!with_norm) { + GGML_UNUSED(clamp_val); + } } template @@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, int32_t * ids, const int n_rows, const int n_expert, - const int n_expert_used) { + const int n_expert_used, + const float clamp_val) { static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); - const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); @@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, switch (n_expert) { case 1: topk_moe_cuda<1, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 2: topk_moe_cuda<2, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 4: topk_moe_cuda<4, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 8: topk_moe_cuda<8, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 16: topk_moe_cuda<16, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 32: topk_moe_cuda<32, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 64: topk_moe_cuda<64, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 128: topk_moe_cuda<128, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 256: topk_moe_cuda<256, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; case 512: topk_moe_cuda<512, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used); + <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); break; default: GGML_ASSERT(false && "fatal error"); @@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * weights, ggml_tensor * ids, const bool with_norm, - const bool delayed_softmax) { + const bool delayed_softmax, + ggml_tensor * clamp) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const int n_expert_used = weights->ne[1]; + float clamp_val = -INFINITY; if (with_norm) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); } else { + GGML_ASSERT(clamp == nullptr); if (delayed_softmax) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, + clamp_val); } else { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used); + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, + clamp_val); } } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) { +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) { float scale = 1.0f; float max_bias = 0.0f; @@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso return false; } + if (clamp) { + if (clamp->op != GGML_OP_CLAMP) { + return false; + } + float max_val = ggml_get_op_params_f32(clamp, 1); + + if (max_val != INFINITY) { + return false; + } + } + + return true; } std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) { static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }; diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index cc2fbfe9e6..2eff408b03 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * weights, ggml_tensor * ids, const bool with_norm, - const bool delayed_softmax = false); + const bool delayed_softmax = false, + ggml_tensor * weight_clamp = nullptr); -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr); std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 3c564566a5..5f0d3a6726 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) { } static __device__ __forceinline__ float op_gelu(float x) { - const float GELU_COEF_A = 0.044715f; - const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - - return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + return ggml_cuda_op_gelu_single(x); } static __device__ __forceinline__ float op_gelu_erf(float x) { @@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) { } static __device__ __forceinline__ float op_silu(float x) { - return x / (1.0f + expf(-x)); + return ggml_cuda_op_silu_single(x); } static __device__ __forceinline__ float op_tanh(float x) { @@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons float xi = x[j0]; float gi = g[j1]; - xi = fminf(xi, limit); - gi = fmaxf(fminf(gi, limit), -limit); - float out_glu = xi / (1.0f + expf(-xi * alpha)); - out_glu = out_glu * (1.0f + gi); - - dst[i] = out_glu; + dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit); } template diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 8e7644fcd9..6c738cefec 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,3 +1,4 @@ +#pragma once #include "common.cuh" #define CUDA_NEG_BLOCK_SIZE 256 @@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { + return x / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + + return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x))); +} + +__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) { + x = fminf(x, limit); + g = fmaxf(fminf(g, limit), -limit); + + float out_glu = x / (1.0f + expf(-x * alpha)); + out_glu = out_glu * (1.0f + g); + return out_glu; +} diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index ef48aa5f97..35b7e61d80 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -126,8 +126,8 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } else if (mode == GGML_SCALE_MODE_BILINEAR) { float pixel_offset = 0.5f; if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1; pixel_offset = 0.0f; } upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt new file mode 100644 index 0000000000..166825c2c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -0,0 +1,68 @@ +include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) +include(ExternalProject) + +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) + +add_library(htp_iface OBJECT + ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c) + +set_target_properties(htp_iface PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(htp_iface PUBLIC + ${HEXAGON_SDK_ROOT}/incs + ${HEXAGON_SDK_ROOT}/incs/stddef + ${HEXAGON_SDK_ROOT}/utils/examples + ${CMAKE_CURRENT_SOURCE_DIR}/htp + ${CMAKE_CURRENT_BINARY_DIR}) + +build_idl(htp/htp_iface.idl htp_iface) + +if (CMAKE_SYSTEM_NAME MATCHES Android) + target_link_options(htp_iface PUBLIC -llog -ldl) +elseif (CMAKE_SYSTEM_NAME MATCHES Windows) + target_precompile_headers(htp_iface PUBLIC ) +else() + target_link_options(htp_iface PUBLIC -ldl) +endif() + +link_custom_library(htp_iface cdsprpc) +link_custom_library(htp_iface rpcmem) + +set(TARGET_NAME ggml-hexagon) +ggml_add_backend_library(${TARGET_NAME} + ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h) + +target_link_libraries(${TARGET_NAME} PRIVATE htp_iface) +target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR}) + +# Build HTP bits +set(HTP_CMAKE_ARGS + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} + -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} + -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} + -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}) + +ExternalProject_Add(htp-v73 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") + +ExternalProject_Add(htp-v75 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75") + +ExternalProject_Add(htp-v79 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79") + +ExternalProject_Add(htp-v81 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81") + +# Install Hexagon skels required at runtime +install(FILES + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so + ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so + TYPE LIB) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp new file mode 100644 index 0000000000..2d376a6025 --- /dev/null +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -0,0 +1,3648 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef _WIN32 +# include +# ifndef _WINDOWS +# define _WINDOWS +# endif +#else +# include +# include +#endif + +#pragma clang diagnostic ignored "-Wnested-anon-types" +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" + +#include "htp-utils.h" + +#include +#include +#include + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" +#include "ggml-hexagon.h" +#include "ggml-impl.h" +#include "ggml-quants.h" +#include "htp-msg.h" +#include "htp_iface.h" + +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_arch = 0; // autodetect +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; +static int opt_hostbuf = 1; +static int opt_experimental = 0; + +// Enable all stages by default +static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; +static int opt_opsync = 0; // synchronous ops + +#define HEX_VERBOSE(...) \ + if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) + +#define HEX_PROFILE(...) \ + if (opt_profile) GGML_LOG_INFO(__VA_ARGS__) + +static inline uint64_t hex_is_aligned(void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline size_t hex_round_up(size_t n, size_t m) { + return m * ((n + m - 1) / m); +} + +static const char * status_to_str(uint32_t status) { + switch (status) { + case HTP_STATUS_OK: + return "OK"; + case HTP_STATUS_NO_SUPPORT: + return "NO-SUPPORT"; + case HTP_STATUS_INVAL_PARAMS: + return "INVAL-PARAMS"; + case HTP_STATUS_VTCM_TOO_SMALL: + return "VTCM-TOO-SMALL"; + case HTP_STATUS_INTERNAL_ERR: + return "INTERNAL-ERROR"; + default: + return "UNKNOWN"; + } +} + +// ** debug helpers + +static inline int hex_format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } +} + +static inline void hex_format_op_dims(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += hex_format_tensor_dims(p, t->src[0]); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += hex_format_tensor_dims(p, t->src[i]); + } + + p += sprintf(p, " -> "); + } + + // format self dims separately for better visual alignment + char self[64]; + hex_format_tensor_dims(self, t); + + p += sprintf(p, "%s", self); +} + +static inline int hex_format_tensor_strides(char * str, const struct ggml_tensor * t) { + const char * c = ggml_is_contiguous(t) ? "" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], + (size_t) t->nb[3], c); + } +} + +static inline void hex_format_op_strides(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += hex_format_tensor_strides(p, t->src[0]); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += hex_format_tensor_strides(p, t->src[i]); + } + + p += sprintf(p, " -> "); + } + + // format self dims separately for better visual alignment + char self[64]; + hex_format_tensor_strides(self, t); + + p += sprintf(p, "%s", self); +} + +static inline void hex_format_op_types(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(t->type)); +} + +static inline const char * hex_tensor_buff_name(const struct ggml_tensor * t) { + if (t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; +} + +static inline void hex_format_op_buffs(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", hex_tensor_buff_name(t->src[0])); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", hex_tensor_buff_name(t->src[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", hex_tensor_buff_name(t)); +} + +static inline void hex_format_op_names(char * str, const struct ggml_tensor * t) { + char * p = str; + + // append src0 and src1 (if any) + if (t->src[0]) { + p += sprintf(p, "%s", t->src[0]->name); + + for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", t->src[i]->name); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", t->name); +} + +// ** backend sessions + +struct ggml_hexagon_session { + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); + ~ggml_hexagon_session() noexcept(true); + + void allocate(int dev_id) noexcept(false); + void release() noexcept(true); + + void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); + void flush(); + + ggml_backend_buffer_type buffer_type; + ggml_backend_buffer_type repack_buffer_type; + + std::string name; + remote_handle64 handle; + dspqueue_t queue; + uint32_t session_id; + uint32_t domain_id; + uint64_t queue_id; + int dev_id; + bool valid_session; + bool valid_handle; + bool valid_queue; + bool valid_iface; + std::atomic op_pending; + uint32_t prof_usecs; + uint32_t prof_cycles; + uint32_t prof_pkts; +}; + +void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { + // Bump pending flag (cleared in the session::flush once we get the responce) + this->op_pending++; // atomic inc + + int err = dspqueue_write(this->queue, + 0, // flags - the framework will autoset this + n_bufs, // number of buffers + bufs, // buffer references + sizeof(req), + (const uint8_t *) &req, // Message + 1000000 // Timeout + ); + + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); + } + + if (sync) { + flush(); + } +} + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush() { + dspqueue_t q = this->queue; + + // Repeatedly read packets from the queue until it's empty. We don't + // necessarily get a separate callback for each packet, and new packets + // may arrive while we're processing the previous one. + + while (this->op_pending) { + struct htp_general_rsp rsp; + uint32_t rsp_size; + uint32_t flags; + + struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; + uint32_t n_bufs; + + // Read response packet from queue + int err = dspqueue_read(q, &flags, + HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references + &n_bufs, // Number of buffer references + bufs, // Buffer references + sizeof(rsp), // Max message length + &rsp_size, // Message length + (uint8_t *) &rsp, + 1000000); // Timeout + + if (err == AEE_EEXPIRED) { + // TODO: might need to bail out if the HTP is stuck on something + continue; + } + + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); + } + + // Basic sanity checks + if (rsp_size != sizeof(rsp)) { + GGML_ABORT("ggml-hex: dspcall : bad response (size)\n"); + } + + if (rsp.status != HTP_STATUS_OK) { + GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status)); + // TODO: handle errors + } + + // TODO: update profiling implementation, currently only works for opt_opsync mode + this->prof_usecs = rsp.prof_usecs; + this->prof_cycles = rsp.prof_cycles; + this->prof_pkts = rsp.prof_pkts; + + this->op_pending--; // atomic dec + } +} + +// ** backend buffers + +struct ggml_backend_hexagon_buffer_type_context { + ggml_backend_hexagon_buffer_type_context(const std::string & name, ggml_hexagon_session * sess) { + this->sess = sess; + this->name = name; + } + + ggml_hexagon_session * sess; + std::string name; +}; + +struct ggml_backend_hexagon_buffer_context { + bool mmap_to(ggml_hexagon_session * s) { + HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n", + s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd, + (int) this->repack); + + int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", + s->domain_id, this->size, this->fd, (unsigned) err); + return false; + } + + return true; + } + + bool mmap() { + if (this->mapped) { + return true; + } + if (!mmap_to(this->sess)) { + return false; + } + this->mapped = true; + return true; + } + + void munmap() { + if (!this->mapped) { + return; + } + + fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size); + this->mapped = false; + } + + ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { + size += 4 * 1024; // extra page for padding + + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + if (!this->base) { + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); + throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); + } + + this->fd = rpcmem_to_fd(this->base); + if (this->fd < 0) { + GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base); + rpcmem_free(this->base); + this->base = NULL; + throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)"); + } + + HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(), + (void *) this->base, size, this->fd, (int) repack); + + this->sess = sess; + this->size = size; + this->mapped = false; + this->repack = repack; + } + + ~ggml_backend_hexagon_buffer_context() { + munmap(); + if (this->base) { + rpcmem_free(this->base); + this->base = NULL; + } + } + + ggml_hexagon_session * sess; // primary session + uint8_t * base; + size_t size; + int fd; + bool mapped; // mmap is done + bool repack; // repacked buffer +}; + +static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) { + return static_cast(buffer->buft->context)->sess; +} + +static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) { + auto ctx = static_cast(buffer->context); + delete ctx; +} + +static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) { + auto ctx = static_cast(buffer->context); + return ctx->base; +} + +static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + auto ctx = static_cast(buffer->context); + auto sess = ctx->sess; + + HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(), + tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage, + (int) ctx->repack); + + if (tensor->view_src != NULL && tensor->view_offs == 0) { + ; // nothing to do for the view + } else { + if (!ctx->mapped) { + ctx->mmap(); + } + } + return GGML_STATUS_SUCCESS; +} + +// ======== Q4x4x2 ==================== +struct x2_q4 { + int v[2]; +}; + +static x2_q4 unpack_q4(uint8_t v) { + x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 }; + return x; +} + +static void dump_block_q4_0(const block_q4_0 * b, int i) { + HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0], + unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1], + unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1], + GGML_FP16_TO_FP32(b->d)); +} + +static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) { + static const int qk = QK_Q4_0x4x2; + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded) + + const uint8_t * v_q = v + 0; // quants first + const uint8_t * v_d = v + qrow_size; // then scales + + const uint8_t * q = v_q + i * qblk_size; + const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); + + HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, + unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0], + unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0], + unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0], + GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); + + HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", + i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1], + unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1], + unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1], + GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); +} + +static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) { + static const int qk = QK4_0; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_0; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_q4_0(&x[i * 8 + 0], 0); + dump_block_q4_0(&x[i * 8 + 1], 1); + dump_block_q4_0(&x[i * 8 + 2], 2); + dump_block_q4_0(&x[i * 8 + 3], 3); + dump_block_q4_0(&x[i * 8 + 4], 4); + dump_block_q4_0(&x[i * 8 + 5], 5); + dump_block_q4_0(&x[i * 8 + 6], 6); + dump_block_q4_0(&x[i * 8 + 7], 7); + } + } + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_0_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_0_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_0_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_0_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_0_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_0_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = (qs[j + 128] << 4) | qs[j]; + } + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Repack the scales + ggml_half * d = (ggml_half *) (y_d + i * dblk_size); + d[0] = x[i * 8 + 0].d; + d[1] = x[i * 8 + 1].d; + d[2] = x[i * 8 + 2].d; + d[3] = x[i * 8 + 3].d; + d[4] = x[i * 8 + 4].d; + d[5] = x[i * 8 + 5].d; + d[6] = x[i * 8 + 6].d; + d[7] = x[i * 8 + 7].d; + } + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q4x4x2(y, i, k); + } + } +} + +static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q4x4x2(y, i, k); + } + } + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + qs[j] = q[j] & 0xf; + qs[j + 128] = q[j] >> 4; + } + + pack_q4_0_quants(&x[i * 8 + 0], qs, 0); + pack_q4_0_quants(&x[i * 8 + 1], qs, 1); + pack_q4_0_quants(&x[i * 8 + 2], qs, 2); + pack_q4_0_quants(&x[i * 8 + 3], qs, 3); + pack_q4_0_quants(&x[i * 8 + 4], qs, 4); + pack_q4_0_quants(&x[i * 8 + 5], qs, 5); + pack_q4_0_quants(&x[i * 8 + 6], qs, 6); + pack_q4_0_quants(&x[i * 8 + 7], qs, 7); + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); + x[i * 8 + 0].d = d[0]; + x[i * 8 + 1].d = d[1]; + x[i * 8 + 2].d = d[2]; + x[i * 8 + 3].d = d[3]; + x[i * 8 + 4].d = d[4]; + x[i * 8 + 5].d = d[5]; + x[i * 8 + 6].d = d[6]; + x[i * 8 + 7].d = d[7]; + } + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_q4_0(&x[i * 8 + 0], 0); + dump_block_q4_0(&x[i * 8 + 1], 1); + dump_block_q4_0(&x[i * 8 + 2], 2); + dump_block_q4_0(&x[i * 8 + 3], 3); + dump_block_q4_0(&x[i * 8 + 4], 4); + dump_block_q4_0(&x[i * 8 + 5], 5); + dump_block_q4_0(&x[i * 8 + 6], 6); + dump_block_q4_0(&x[i * 8 + 7], 7); + } + } +} + +static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + // Init the quants such that they unpack into zeros + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 8, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_0_quants(&x[i * 8 + 0], qs, 0); + pack_q4_0_quants(&x[i * 8 + 1], qs, 1); + pack_q4_0_quants(&x[i * 8 + 2], qs, 2); + pack_q4_0_quants(&x[i * 8 + 3], qs, 3); + pack_q4_0_quants(&x[i * 8 + 4], qs, 4); + pack_q4_0_quants(&x[i * 8 + 5], qs, 5); + pack_q4_0_quants(&x[i * 8 + 6], qs, 6); + pack_q4_0_quants(&x[i * 8 + 7], qs, 7); + } + + // Init the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + x[i * 8 + 0].d = 0; + x[i * 8 + 1].d = 0; + x[i * 8 + 2].d = 0; + x[i * 8 + 3].d = 0; + x[i * 8 + 4].d = 0; + x[i * 8 + 5].d = 0; + x[i * 8 + 6].d = 0; + x[i * 8 + 7].d = 0; + } +} + +// repack q4_0 data into q4x4x2 tensor +static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// repack q4x4x2 tensor into q4_0 data +static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// ======== Q8x4x2 ==================== +static void dump_block_q8_0(const block_q8_0 * b, int i) { + HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], + b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d)); +} + +static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) { + static const int qk = QK_Q8_0x4x2; + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk; // int8 + const int qrow_size = k; // int8 (not padded) + + const uint8_t * v_q = v + 0; // quants first + const uint8_t * v_d = v + qrow_size; // then scales + + const uint8_t * q = v_q + i * qblk_size; + const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); + + HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, + q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127], + GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); + + HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", + i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255], + GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); +} + +static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) { + static const int qk = QK8_0; + + for (unsigned int i = 0; i < qk; ++i) { + qs[bi * qk + i] = x->qs[i]; + } +} + +static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK8_0; + + for (unsigned int i = 0; i < qk; ++i) { + x->qs[i] = qs[bi * qk + i]; + } +} + +static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { + static const int qk = QK_Q8_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk; // int8 + const int qrow_size = k; // int8 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_q8_0(&x[i * 8 + 0], 0); + dump_block_q8_0(&x[i * 8 + 1], 1); + dump_block_q8_0(&x[i * 8 + 2], 2); + dump_block_q8_0(&x[i * 8 + 3], 3); + dump_block_q8_0(&x[i * 8 + 4], 4); + dump_block_q8_0(&x[i * 8 + 5], 5); + dump_block_q8_0(&x[i * 8 + 6], 6); + dump_block_q8_0(&x[i * 8 + 7], 7); + } + } + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q8_0x4x2]; // unpacked quants + + unpack_q8_0_quants(qs, &x[i * 8 + 0], 0); + unpack_q8_0_quants(qs, &x[i * 8 + 1], 1); + unpack_q8_0_quants(qs, &x[i * 8 + 2], 2); + unpack_q8_0_quants(qs, &x[i * 8 + 3], 3); + unpack_q8_0_quants(qs, &x[i * 8 + 4], 4); + unpack_q8_0_quants(qs, &x[i * 8 + 5], 5); + unpack_q8_0_quants(qs, &x[i * 8 + 6], 6); + unpack_q8_0_quants(qs, &x[i * 8 + 7], 7); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk; j++) { + q[j] = qs[j]; + } + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Repack the scales + ggml_half * d = (ggml_half *) (y_d + i * dblk_size); + d[0] = x[i * 8 + 0].d; + d[1] = x[i * 8 + 1].d; + d[2] = x[i * 8 + 2].d; + d[3] = x[i * 8 + 3].d; + d[4] = x[i * 8 + 4].d; + d[5] = x[i * 8 + 5].d; + d[6] = x[i * 8 + 6].d; + d[7] = x[i * 8 + 7].d; + } + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q8x4x2(y, i, k); + } + } +} + +static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q8_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int dblk_size = 8 * 2; // 8x __fp16 + const int qblk_size = qk; // int8 + const int qrow_size = k; // int8 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_q8x4x2(y, i, k); + } + } + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk; j++) { + qs[j] = q[j]; + } + + pack_q8_0_quants(&x[i * 8 + 0], qs, 0); + pack_q8_0_quants(&x[i * 8 + 1], qs, 1); + pack_q8_0_quants(&x[i * 8 + 2], qs, 2); + pack_q8_0_quants(&x[i * 8 + 3], qs, 3); + pack_q8_0_quants(&x[i * 8 + 4], qs, 4); + pack_q8_0_quants(&x[i * 8 + 5], qs, 5); + pack_q8_0_quants(&x[i * 8 + 6], qs, 6); + pack_q8_0_quants(&x[i * 8 + 7], qs, 7); + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); + x[i * 8 + 0].d = d[0]; + x[i * 8 + 1].d = d[1]; + x[i * 8 + 2].d = d[2]; + x[i * 8 + 3].d = d[3]; + x[i * 8 + 4].d = d[4]; + x[i * 8 + 5].d = d[5]; + x[i * 8 + 6].d = d[6]; + x[i * 8 + 7].d = d[7]; + } + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_q8_0(&x[i * 8 + 0], 0); + dump_block_q8_0(&x[i * 8 + 1], 1); + dump_block_q8_0(&x[i * 8 + 2], 2); + dump_block_q8_0(&x[i * 8 + 3], 3); + dump_block_q8_0(&x[i * 8 + 4], 4); + dump_block_q8_0(&x[i * 8 + 5], 5); + dump_block_q8_0(&x[i * 8 + 6], 6); + dump_block_q8_0(&x[i * 8 + 7], 7); + } + } +} + +static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { + static const int qk = QK_Q8_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + // Init the quants such that they unpack into zeros + uint8_t qs[QK_Q8_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q8_0_quants(&x[i * 8 + 0], qs, 0); + pack_q8_0_quants(&x[i * 8 + 1], qs, 1); + pack_q8_0_quants(&x[i * 8 + 2], qs, 2); + pack_q8_0_quants(&x[i * 8 + 3], qs, 3); + pack_q8_0_quants(&x[i * 8 + 4], qs, 4); + pack_q8_0_quants(&x[i * 8 + 5], qs, 5); + pack_q8_0_quants(&x[i * 8 + 6], qs, 6); + pack_q8_0_quants(&x[i * 8 + 7], qs, 7); + } + + // Init the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + x[i * 8 + 0].d = 0; + x[i * 8 + 1].d = 0; + x[i * 8 + 2].d = 0; + x[i * 8 + 3].d = 0; + x[i * 8 + 4].d = 0; + x[i * 8 + 5].d = 0; + x[i * 8 + 6].d = 0; + x[i * 8 + 7].d = 0; + } +} + +// repack q8_0 data into q8x4x2 tensor +static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// repack q8x4x2 tensor into q8_0 data +static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// ======== MXFP4x4x2 ==================== +struct x2_mxfp4 { + int v[2]; +}; + +static x2_mxfp4 unpack_mxfp4(uint8_t v) { + x2_mxfp4 x; + x.v[0] = kvalues_mxfp4[(v & 0x0f)]; + x.v[1] = kvalues_mxfp4[(v >> 4)]; + return x; +} + +static void dump_block_mxfp4(const block_mxfp4 * b, int i) { + HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0], + unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0], + unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1], + unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e)); +} + +static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) { + static const int qk = QK_MXFP4x4x2; + const int eblk_size = 8 * 1; // 8x E8M0 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded) + + const uint8_t * v_q = v + 0; // quants first + const uint8_t * v_e = v + qrow_size; // then scales + + const uint8_t * q = v_q + i * qblk_size; + const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size); + + HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, + unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0], + unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0], + unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0], + unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]), + GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3])); + + HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", + i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1], + unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1], + unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1], + unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]), + GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7])); +} + +static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) { + static const int qk = QK_MXFP4; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = (x->qs[i] & 0x0F); + const uint8_t x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_0; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { + static const int qk = QK_MXFP4x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int eblk_size = 8 * 1; // 8x E8M0 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_e = y + qrow_size; // then scales + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_mxfp4(&x[i * 8 + 0], 0); + dump_block_mxfp4(&x[i * 8 + 1], 1); + dump_block_mxfp4(&x[i * 8 + 2], 2); + dump_block_mxfp4(&x[i * 8 + 3], 3); + dump_block_mxfp4(&x[i * 8 + 4], 4); + dump_block_mxfp4(&x[i * 8 + 5], 5); + dump_block_mxfp4(&x[i * 8 + 6], 6); + dump_block_mxfp4(&x[i * 8 + 7], 7); + } + } + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + + unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0); + unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1); + unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2); + unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3); + unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4); + unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5); + unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); + unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = (qs[j + 128] << 4) | qs[j]; + } + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Repack the scales + uint8_t * e = (uint8_t *) (y_e + i * eblk_size); + e[0] = x[i * 8 + 0].e; + e[1] = x[i * 8 + 1].e; + e[2] = x[i * 8 + 2].e; + e[3] = x[i * 8 + 3].e; + e[4] = x[i * 8 + 4].e; + e[5] = x[i * 8 + 5].e; + e[6] = x[i * 8 + 6].e; + e[7] = x[i * 8 + 7].e; + } + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_mxfp4x4x2(y, i, k); + } + } +} + +static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_MXFP4x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + const int eblk_size = 8 * 1; // 8x E8M0 + const int qblk_size = qk / 2; // int4 + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_e = y + qrow_size; // then scales + + if (opt_verbose > 1) { + for (int i = 0; i < nb; i++) { + dump_packed_block_mxfp4x4x2(y, i, k); + } + } + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + qs[j] = q[j] & 0xf; + qs[j + 128] = q[j] >> 4; + } + + pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); + pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); + pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); + pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); + pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); + pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); + pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); + pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); + } + + // Repack the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); + x[i * 8 + 0].e = e[0]; + x[i * 8 + 1].e = e[1]; + x[i * 8 + 2].e = e[2]; + x[i * 8 + 3].e = e[3]; + x[i * 8 + 4].e = e[4]; + x[i * 8 + 5].e = e[5]; + x[i * 8 + 6].e = e[6]; + x[i * 8 + 7].e = e[7]; + } + + if (opt_verbose > 2) { + for (int i = 0; i < nb; i++) { + dump_block_mxfp4(&x[i * 8 + 0], 0); + dump_block_mxfp4(&x[i * 8 + 1], 1); + dump_block_mxfp4(&x[i * 8 + 2], 2); + dump_block_mxfp4(&x[i * 8 + 3], 3); + dump_block_mxfp4(&x[i * 8 + 4], 4); + dump_block_mxfp4(&x[i * 8 + 5], 5); + dump_block_mxfp4(&x[i * 8 + 6], 6); + dump_block_mxfp4(&x[i * 8 + 7], 7); + } + } +} + +static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { + static const int qk = QK_MXFP4x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + // Init the quants such that they unpack into zeros + uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); + pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); + pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); + pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); + pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); + pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); + pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); + pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); + } + + // Init the scales + // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) + // the last block is truncated and overriden by the scales. + for (int i = 0; i < nb; i++) { + // Unpack the scales + x[i * 8 + 0].e = 0; + x[i * 8 + 1].e = 0; + x[i * 8 + 2].e = 0; + x[i * 8 + 3].e = 0; + x[i * 8 + 4].e = 0; + x[i * 8 + 5].e = 0; + x[i * 8 + 6].e = 0; + x[i * 8 + 7].e = 0; + } +} + +// repack mxfp4 data into mxfp4x4x2 tensor +static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, + size, t->ne[0], nrows, row_size); + + init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +// repack mxfp4x4x2 tensor into mxfp4 data +static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad + size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, + size, t->ne[0], nrows, row_size); + + memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < nrows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_pd, src, row_size); + unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; + auto sess = ctx->sess; + + HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, + offset, size); + + switch (tensor->type) { + case GGML_TYPE_Q4_0: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_q4_0_q4x4x2(tensor, data, size); + break; + + case GGML_TYPE_Q8_0: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_q8_0_q8x4x2(tensor, data, size); + break; + + case GGML_TYPE_MXFP4: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_mxfp4_mxfp4x4x2(tensor, data, size); + break; + + default: + memcpy((char *) tensor->data + offset, data, size); + break; + } +} + +static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; + auto sess = ctx->sess; + + HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, + offset, size); + + switch (tensor->type) { + case GGML_TYPE_Q4_0: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_q4x4x2_q4_0(data, tensor, size); + break; + + case GGML_TYPE_Q8_0: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_q8x4x2_q8_0(data, tensor, size); + break; + + case GGML_TYPE_MXFP4: + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + repack_mxfp4x4x2_mxfp4(data, tensor, size); + break; + + default: + memcpy(data, (const char *) tensor->data + offset, size); + break; + } +} + +static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const struct ggml_tensor * src, + struct ggml_tensor * dst) { + GGML_UNUSED(buffer); + GGML_UNUSED(src); + GGML_UNUSED(dst); + // we might optimize this later, for now take the slow path (ie get/set_tensor) + return false; +} + +static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; + auto sess = ctx->sess; + HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size); + memset(ctx->base, value, ctx->size); +} + +static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { + /* .free_buffer = */ ggml_backend_hexagon_buffer_free_buffer, + /* .get_base = */ ggml_backend_hexagon_buffer_get_base, + /* .init_tensor = */ ggml_backend_hexagon_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, + /* .clear = */ ggml_backend_hexagon_buffer_clear, + /* .reset = */ NULL, +}; + +// ** backend buffer type + +static const char * ggml_backend_hexagon_buffer_type_name(ggml_backend_buffer_type_t buffer_type) { + return static_cast(buffer_type->context)->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( + ggml_backend_buffer_type_t buffer_type, size_t size) { + auto sess = static_cast(buffer_type->context)->sess; + try { + ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + } catch (std::exception const &exc) { + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + return nullptr; + } +} + +static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffer( + ggml_backend_buffer_type_t buffer_type, size_t size) { + auto sess = static_cast(buffer_type->context)->sess; + try { + ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + } catch (std::exception const &exc) { + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + return nullptr; + } +} + +static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { + return 128; // HVX alignment + GGML_UNUSED(buffer_type); +} + +static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) { + return ggml_nbytes(t); +} + +static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { + return 1 * 1024 * 1024 * 1024; // 1GB per buffer + GGML_UNUSED(buffer_type); +} + +static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return opt_hostbuf; + GGML_UNUSED(buft); +} + +static bool ggml_backend_hexagon_repack_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_type_i ggml_backend_hexagon_buffer_type_interface = { + /* .get_name = */ ggml_backend_hexagon_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_hexagon_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_hexagon_buffer_type_is_host, +}; + +static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interface = { + /* .get_name = */ ggml_backend_hexagon_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_hexagon_repack_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_hexagon_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_hexagon_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_hexagon_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, +}; + +void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { + this->valid_session = false; + this->valid_handle = false; + this->valid_queue = false; + this->valid_iface = false; + + this->domain_id = 3; // Default for CDSP, updated after the session is created + this->session_id = 0; // Default for CDSP, updated after the session is created + this->dev_id = dev_id; + this->name = std::string("HTP") + std::to_string(dev_id); + + this->op_pending = 0; + this->prof_usecs = 0; + this->prof_cycles = 0; + this->prof_pkts = 0; + + GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + + domain * my_domain = get_domain(this->domain_id); + if (my_domain == NULL) { + GGML_LOG_ERROR("ggml-hex: unable to get domain struct for CDSP\n"); + throw std::runtime_error("ggml-hex: failed to get CDSP domain (see log for details)"); + } + + // Create new session + if (dev_id != 0) { + struct remote_rpc_reserve_new_session n; + n.domain_name_len = strlen(CDSP_DOMAIN_NAME); + n.domain_name = const_cast(CDSP_DOMAIN_NAME); + n.session_name = const_cast(this->name.c_str()); + n.session_name_len = this->name.size(); + + int err = remote_session_control(FASTRPC_RESERVE_NEW_SESSION, (void *) &n, sizeof(n)); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to reserve new session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: remote_session_control(new-sess) failed (see log for details)"); + } + + // Save the IDs + this->session_id = n.session_id; + this->domain_id = n.effective_domain_id; + this->valid_session = true; + } + + // Get session URI + char htp_uri[256]; + sprintf(htp_uri, "file:///libggml-htp-v%u.so?htp_iface_skel_handle_invoke&_modver=1.0", opt_arch); + + char session_uri[256]; + { + struct remote_rpc_get_uri u; + u.session_id = this->session_id; + u.domain_name = const_cast(CDSP_DOMAIN_NAME); + u.domain_name_len = strlen(CDSP_DOMAIN_NAME); + u.module_uri = const_cast(htp_uri); + u.module_uri_len = strlen(htp_uri); + u.uri = session_uri; + u.uri_len = sizeof(session_uri); + + int err = remote_session_control(FASTRPC_GET_URI, (void *) &u, sizeof(u)); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to get URI for session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: remote_session_control(get-uri) failed (see log for details)"); + } + } + + // Enable Unsigned PD + { + struct remote_rpc_control_unsigned_module u; + u.domain = this->domain_id; + u.enable = 1; + int err = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, (void *) &u, sizeof(u)); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to enable unsigned PD for session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: remote_session_control(unsign) failed (see log for details)"); + } + } + + // Open session + int err = htp_iface_open(session_uri, &this->handle); + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: failed to open session %d : error 0x%x\n", dev_id, err); + throw std::runtime_error("ggml-hex: failed to open session (see log for details)"); + } + + this->valid_handle = true; + + GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + + // Enable FastRPC QoS mode + { + struct remote_rpc_control_latency l; + l.enable = 1; + + int err = remote_handle64_control(this->handle, DSPRPC_CONTROL_LATENCY, (void *) &l, sizeof(l)); + if (err != 0) { + GGML_LOG_WARN("ggml-hex: failed to enable fastrpc QOS mode: 0x%08x\n", (unsigned) err); + } + } + + // Now let's setup the DSP queue + err = dspqueue_create(this->domain_id, + 0, // Flags + 128 * 1024, // Request queue size (in bytes) + 64 * 1024, // Response queue size (in bytes) + nullptr, // Read packet callback (we handle reads explicitly) + nullptr, // Error callback (we handle errors during reads) + (void *) this, // Callback context + &queue); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: %s dspqueue_create failed: 0x%08x\n", this->name.c_str(), (unsigned) err); + throw std::runtime_error("ggml-hex: failed to create dspqueue (see log for details)"); + } + + this->valid_queue = true; + + // Export queue for use on the DSP + err = dspqueue_export(queue, &this->queue_id); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: dspqueue_export failed: 0x%08x\n", (unsigned) err); + throw std::runtime_error("ggml-hex: dspqueue export failed (see log for details)"); + } + + if (opt_etm) { + err = htp_iface_enable_etm(this->handle); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); + } + } + + // Start the DSP-side service. We need to pass the queue ID to the + // DSP in a FastRPC call; the DSP side will import the queue and start + // listening for packets in a callback. + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); + throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); + } + this->valid_iface = true; +} + +void ggml_hexagon_session::release() noexcept(true) { + GGML_LOG_INFO("ggml-hex: releasing session: %s\n", this->name.c_str()); + + int err; + + // Stop the DSP-side service and close the queue + if (this->valid_iface) { + err = htp_iface_stop(this->handle); + if (err != 0) { + GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); + } + } + + if (opt_etm) { + err = htp_iface_disable_etm(this->handle); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); + } + } + + if (this->valid_queue) { + err = dspqueue_close(queue); + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_close failed: 0x%08x\n", (unsigned) err); + } + } + + if (this->valid_handle) { + htp_iface_close(this->handle); + } +} + +ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false) { + buffer_type.context = nullptr; + repack_buffer_type.context = nullptr; + + buffer_type.device = dev; + repack_buffer_type.device = dev; + + try { + allocate(dev_id); + + buffer_type.iface = ggml_backend_hexagon_buffer_type_interface; + buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name, this); + + repack_buffer_type.iface = ggml_backend_hexagon_repack_buffer_type_interface; + repack_buffer_type.context = new ggml_backend_hexagon_buffer_type_context(this->name + "-REPACK", this); + } catch (std::exception const &exc) { + release(); + throw; + } +} + +ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { + release(); + + delete static_cast(buffer_type.context); + delete static_cast(repack_buffer_type.context); +} + +// ** backend interface + +static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { + return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; +} + +static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { + return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; +} + +static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) { + if (x->ne[0] != y->ne[0]) { + return false; + } + if (x->ne[1] != y->ne[1]) { + return false; + } + if (x->ne[2] != y->ne[2]) { + return false; + } + if (x->ne[3] != y->ne[3]) { + return false; + } + + return true; +} + +static bool hex_supported_src0_type(ggml_type t) { + return t == GGML_TYPE_F32; +} + +static bool hex_supported_src1_type(ggml_type t) { + return t == GGML_TYPE_F32; +} + +static bool hex_supported_src2_type(ggml_type t) { + return t == GGML_TYPE_F32; +} + +static bool hex_supported_src1_type2(ggml_type t) { + return t == GGML_TYPE_F16; +} + +static bool hex_supported_src1_type3(ggml_type t) { + return t == GGML_TYPE_I32; +} + +static bool hex_supported_dst_type(ggml_type t) { + return t == GGML_TYPE_F32; +} + +static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) { + // TODO: support broadcast for ne[2 and 3] + if (x->ne[0] != y->ne[0]) { + return false; + } + if (x->ne[2] != y->ne[2]) { + return false; + } + if (x->ne[3] != y->ne[3]) { + return false; + } + return true; +} + +static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // TODO: add support for non-cont tensors + if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + if (src0->ne[0] % 32) { + return false; + } + + if (src0->ne[1] > 16 * 1024) { + return false; // typically the lm-head which would be too large for VTCM + } + + // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false; + if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { + return false; + } + + // src0 (weights) must be repacked + if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) { + return false; + } + break; + + case GGML_TYPE_F16: + if (!opt_experimental) { + return false; + } + break; + + default: + return false; + } + + // src0 & src1 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_I32) { + return false; + } + + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + if ((src0->ne[0] % 32)) { + return false; + } + + // src0 (weights) must be repacked + if (src0->buffer && !ggml_backend_buffer_is_hexagon_repack(src0->buffer)) { + return false; + } + break; + + case GGML_TYPE_F16: + if (!opt_experimental) { + return false; + } + break; + + default: + return false; + } + + // TODO: add support for non-cont tensors + if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + // src0 (weights) must be repacked and mapped to the same session + // src1 & sr2 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (src2->buffer && + (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_src1_type(src1->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + if (!hex_supported_dims2(src0, dst)) { + return false; + } + if (!ggml_can_repeat(src1, src0)) { + return false; + } + + // TODO: add support for non-contigiuos tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + // src0, src1 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_src1_type(src1->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + if (!hex_supported_dims2(src0, dst)) { + return false; + } + + // REVISIT: add support for non-contigiuos tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + // src0, src1 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (src2->buffer && + (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + if (!hex_supported_dims2(src0, dst)) { + return false; + } + + // TODO: add support for non-contigiuos tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + // src0 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess, + const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + if (src1) { + if (!hex_supported_src1_type(src1->type)) { + return false; + } + if (!hex_supported_dims2(src0, src1)) { + return false; + } + if (!ggml_is_contiguous(src1)) { + return false; + } + } + + // src0, src1 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1 && src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + if (src2) { + return false; // FIXME: add support for sinks + } + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + + if (src1) { + if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) { + return false; + } + if (src0->ne[0] != src1->ne[0]) { + return false; + } + if (src1->ne[1] < src0->ne[1]) { + return false; + } + if (src0->ne[2] % src1->ne[2] != 0) { + return false; + } + if (src0->ne[3] % src1->ne[3] != 0) { + return false; + } + } + + if (src1) { + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + } else { + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + } + + // src0, src1 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1 && src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const int32_t * op_params = &op->op_params[0]; + + int mode = op_params[2]; + + if ((mode & GGML_ROPE_TYPE_NEOX) || (mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + return false; + } + if (mode & 1) { + return false; + } + + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; // FIXME: add support for GGML_TYPE_F16 for src0 + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + if (!hex_supported_src1_type3(src1->type)) { + return false; + } + if (src2) { + if (!hex_supported_src2_type(src2->type)) { + return false; + } + int n_dims = op_params[1]; + if (src2->ne[0] < (n_dims / 2)) { + return false; + } + } + + if (src2) { + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(src2) || + !ggml_is_contiguous(dst)) { + return false; + } + } else { + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + } + + // src0, src1, src2 & dst must be mapped to the same session + if (src0->buffer && + (!ggml_backend_buffer_is_hexagon(src0->buffer) || ggml_backend_hexagon_buffer_get_sess(src0->buffer) != sess)) { + return false; + } + if (src1->buffer && + (!ggml_backend_buffer_is_hexagon(src1->buffer) || ggml_backend_hexagon_buffer_get_sess(src1->buffer) != sess)) { + return false; + } + if (src2 && src2->buffer && + (!ggml_backend_buffer_is_hexagon(src2->buffer) || ggml_backend_hexagon_buffer_get_sess(src2->buffer) != sess)) { + return false; + } + if (dst->buffer && + (!ggml_backend_buffer_is_hexagon(dst->buffer) || ggml_backend_hexagon_buffer_get_sess(dst->buffer) != sess)) { + return false; + } + + return true; +} + +// Init hexagon tensor from GGML tensor and Hexagon buffer +static void init_htp_tensor(htp_tensor * h, const ggml_tensor * t) { + h->data = 0; // updated by the receiver + h->type = t->type; + h->ne[0] = t->ne[0]; + h->ne[1] = t->ne[1]; + h->ne[2] = t->ne[2]; + h->ne[3] = t->ne[3]; + h->nb[0] = t->nb[0]; + h->nb[1] = t->nb[1]; + h->nb[2] = t->nb[2]; + h->nb[3] = t->nb[3]; +} + +static void hex_dump_dspbuf(const struct ggml_tensor * t, const dspqueue_buffer * d) { + auto buf = static_cast(t->buffer->context); + auto sess = buf->sess; + + HEX_VERBOSE("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), + t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, + (unsigned int) d->size); +} + +static void ggml_hexagon_mul_mat(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + auto src0_buf = static_cast(src0->buffer->context); + auto src1_buf = static_cast(src1->buffer->context); + auto dst_buf = static_cast(dst->buffer->context); + + uint64_t t1, t2; + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + req.op = HTP_OP_MUL_MAT; + req.flags = flags; + + init_htp_tensor(&req.src0, src0); + init_htp_tensor(&req.src1, src1); + init_htp_tensor(&req.dst, dst); + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + dspqueue_buffer bufs[3]; + memset(bufs, 0, sizeof(bufs)); + + // First buffer Weights. + // The content is static, there is no need to do any cache management + bufs[0].fd = src0_buf->fd; + bufs[0].ptr = src0->data; + bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[0].size = ggml_nbytes(src0); + bufs[0].flags = 0; + + // Second buffer Input Activations. This is a buffer that the CPU + // writes and the DSP reads, so we'll need to flush CPU caches and + // invalidate DSP ones. On platforms with I/O coherency support the + // framework will automatically skip cache operations where possible. + bufs[1].fd = src1_buf->fd; + bufs[1].ptr = src1->data; + bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[1].size = ggml_nbytes(src1); + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Third buffer Output Activations. We'll handle DSP + // cache maintenance in the response message but need to flush + // CPU caches to ensure any previously written dirty lines are + // written out before writes from the DSP start. + bufs[2].fd = dst_buf->fd; + bufs[2].ptr = dst->data; + bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[2].size = ggml_nbytes(dst); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + + // Primary DSP session from the src0 (normally weight) tensor + auto sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req.flags); + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(dst, &bufs[2]); + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, 3, opt_opsync); + } + + t2 = ggml_time_us(); + + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " + "call-usec %llu\n", + sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); +} + +static void ggml_hexagon_mul_mat_id(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + auto src0_buf = static_cast(src0->buffer->context); + auto src1_buf = static_cast(src1->buffer->context); + auto src2_buf = static_cast(src2->buffer->context); + auto dst_buf = static_cast(dst->buffer->context); + + uint64_t t1, t2; + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + req.op = HTP_OP_MUL_MAT_ID; + req.flags = flags; + + init_htp_tensor(&req.src0, src0); + init_htp_tensor(&req.src1, src1); + init_htp_tensor(&req.src2, src2); + init_htp_tensor(&req.dst, dst); + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + dspqueue_buffer bufs[4]; + memset(bufs, 0, sizeof(bufs)); + + // First buffer Weights. + // The content is static, there is no need to do any cache management + bufs[0].fd = src0_buf->fd; + bufs[0].ptr = src0->data; + bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[0].size = ggml_nbytes(src0); + bufs[0].flags = 0; + + // Second buffer Input Activations. This is a buffer that the CPU + // writes and the DSP reads, so we'll need to flush CPU caches and + // invalidate DSP ones. On platforms with I/O coherency support the + // framework will automatically skip cache operations where possible. + bufs[1].fd = src1_buf->fd; + bufs[1].ptr = src1->data; + bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[1].size = ggml_nbytes(src1); + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Third buffer expert IDs. This is a buffer that the CPU + // writes and the DSP reads, so we'll need to flush CPU caches and + // invalidate DSP ones. On platforms with I/O coherency support the + // framework will automatically skip cache operations where possible. + bufs[2].fd = src2_buf->fd; + bufs[2].ptr = src2->data; + bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; + bufs[2].size = ggml_nbytes(src2); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Forth buffer Output Activations. We'll handle DSP + // cache maintenance in the response message but need to flush + // CPU caches to ensure any previously written dirty lines are + // written out before writes from the DSP start. + bufs[3].fd = dst_buf->fd; + bufs[3].ptr = dst->data; + bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[3].size = ggml_nbytes(dst); + bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + + // Primary DSP session from the src0 (normally weight) tensor + auto sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s: %s : %s : %s : %s : %s: flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req.flags); + + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(src2, &bufs[2]); + hex_dump_dspbuf(dst, &bufs[3]); + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, 4, opt_opsync); + } + + t2 = ggml_time_us(); + + HEX_PROFILE( + "ggml-hex: %s matmul-id %s %u:%u:%u:%u x %s %u:%u:%u:%u (%s %u:%u:%u:%u) -> %s %u:%u:%u:%u : op-usec %u " + "op-cycles %u op-pkts %u (%f) call-usec %llu\n", + sess->name.c_str(), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1], (uint32_t) src2->ne[2], + (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); +} + +static void ggml_hexagon_binary(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * node = op; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * dst = node; + + auto src0_buf = static_cast(src0->buffer->context); + auto src1_buf = static_cast(src1->buffer->context); + auto dst_buf = static_cast(dst->buffer->context); + + uint64_t t1 = 0; + uint64_t t2 = 0; + + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + req.flags = flags; + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + switch (node->op) { + case GGML_OP_MUL: + req.op = HTP_OP_MUL; + break; + case GGML_OP_ADD: + req.op = HTP_OP_ADD; + break; + case GGML_OP_SUB: + req.op = HTP_OP_SUB; + break; + default: + GGML_ABORT("ggml-hex: binary : unsupported op:%d\n", node->op); + } + + init_htp_tensor(&req.src0, src0); + init_htp_tensor(&req.src1, src1); + init_htp_tensor(&req.dst, dst); + + dspqueue_buffer bufs[3]; + memset(bufs, 0, sizeof(bufs)); + + // First buffer = First Operand of Binary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + bufs[0].fd = src0_buf->fd; + bufs[0].ptr = src0->data; + bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[0].size = ggml_nbytes(src0); + bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + + // Second buffer = Second Operand of Binary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + bufs[1].fd = src1_buf->fd; + bufs[1].ptr = src1->data; + bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[1].size = ggml_nbytes(src1); + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Third buffer = Output Activations. We'll handle DSP + // cache maintenance in the response message but need to flush + // CPU caches to ensure any previously written dirty lines are + // written out before writes from the DSP start. + bufs[2].fd = dst_buf->fd; + bufs[2].ptr = dst->data; + bufs[2].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[2].size = ggml_nbytes(dst); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + + // Primary DSP session from the src0 tensor + ggml_hexagon_session * sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[16 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), + ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(dst, &bufs[2]); + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, 3, opt_opsync); + } + + t2 = ggml_time_us(); + + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " + "call-usec %llu\n", + sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); +} + +static void ggml_hexagon_add_id(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * node = op; + const struct ggml_tensor * src0 = node->src[0]; + const struct ggml_tensor * src1 = node->src[1]; + const struct ggml_tensor * src2 = node->src[2]; + const struct ggml_tensor * dst = node; + + auto src0_buf = static_cast(src0->buffer->context); + auto src1_buf = static_cast(src1->buffer->context); + auto src2_buf = static_cast(src2->buffer->context); + auto dst_buf = static_cast(dst->buffer->context); + + uint64_t t1 = 0; + uint64_t t2 = 0; + + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + req.flags = flags; + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + switch (node->op) { + case GGML_OP_ADD_ID: + req.op = HTP_OP_ADD_ID; + break; + default: + GGML_ABORT("ggml-hex: unsupported op:%d\n", node->op); + } + + init_htp_tensor(&req.src0, src0); + init_htp_tensor(&req.src1, src1); + init_htp_tensor(&req.src2, src2); + init_htp_tensor(&req.dst, dst); + + dspqueue_buffer bufs[4]; + memset(bufs, 0, sizeof(bufs)); + + // First buffer = input activations + bufs[0].fd = src0_buf->fd; + bufs[0].ptr = src0->data; + bufs[0].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[0].size = ggml_nbytes(src0); + bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + + // Second buffer = experts bias + bufs[1].fd = src1_buf->fd; + bufs[1].ptr = src1->data; + bufs[1].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[1].size = ggml_nbytes(src1); + bufs[1].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Third buffer = activated experts + bufs[2].fd = src2_buf->fd; + bufs[2].ptr = src2->data; + bufs[2].offset = (uint8_t *) src2->data - src2_buf->base; + bufs[2].size = ggml_nbytes(src2); + bufs[2].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + + // Forth buffer = output activations + bufs[3].fd = dst_buf->fd; + bufs[3].ptr = dst->data; + bufs[3].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[3].size = ggml_nbytes(dst); + bufs[3].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + + // Primary DSP session from the src0 tensor + ggml_hexagon_session * sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[16 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), + ggml_op_name(node->op), names, dims, types, strides, buffs, req.flags); + + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(src2, &bufs[2]); + hex_dump_dspbuf(dst, &bufs[3]); + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, 4, opt_opsync); + } + + t2 = ggml_time_us(); + + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) " + "call-usec %llu\n", + sess->name.c_str(), ggml_op_name(node->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); +} + +static void ggml_hexagon_unary(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + uint64_t t1 = 0; + uint64_t t2 = 0; + + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + + memset(&req, 0, sizeof(htp_general_req)); + memcpy(&req.op_params, &op->op_params, sizeof(op->op_params)); + req.flags = flags; + + bool supported = false; + + switch (op->op) { + case GGML_OP_RMS_NORM: + req.op = HTP_OP_RMS_NORM; + supported = true; + break; + + case GGML_OP_UNARY: + if (ggml_get_unary_op(dst) == GGML_UNARY_OP_SILU) { + req.op = HTP_OP_UNARY_SILU; + supported = true; + } + break; + + case GGML_OP_GLU: + if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU) { + req.op = HTP_OP_GLU_SWIGLU; + supported = true; + } else if (ggml_get_glu_op(dst) == GGML_GLU_OP_SWIGLU_OAI) { + req.op = HTP_OP_GLU_SWIGLU_OAI; + supported = true; + } + break; + + case GGML_OP_SOFT_MAX: + req.op = HTP_OP_SOFTMAX; + supported = true; + + default: + break; + } + + if (!supported) { + GGML_ABORT("ggml-hex: unary : unsupported op:%d\n", op->op); + } + + init_htp_tensor(&req.dst, dst); + init_htp_tensor(&req.src0, src0); + if (src1) { + init_htp_tensor(&req.src1, src1); + } + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + dspqueue_buffer bufs[3]; + int n_bufs = 0; + + memset(bufs, 0, sizeof(bufs)); + + // First buffer = Only Operand of Unary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + auto src0_buf = static_cast(src0->buffer->context); + bufs[n_bufs].fd = src0_buf->fd; + bufs[n_bufs].ptr = src0->data; + bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[n_bufs].size = ggml_nbytes(src0); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + ++n_bufs; + + if (src1) { + // Second buffer = Second Operand of Binary op + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + auto src1_buf = static_cast(src1->buffer->context); + bufs[n_bufs].fd = src1_buf->fd; + bufs[n_bufs].ptr = src1->data; + bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[n_bufs].size = ggml_nbytes(src1); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + ++n_bufs; + } + + // Second or third buffer = Output Activations. We'll handle DSP + // Second buffer = Output Activations. We'll handle DSP + // cache maintenance in the response message but need to flush + // CPU caches to ensure any previously written dirty lines are + // written out before writes from the DSP start. + auto dst_buf = static_cast(dst->buffer->context); + bufs[n_bufs].fd = dst_buf->fd; + bufs[n_bufs].ptr = dst->data; + bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[n_bufs].size = ggml_nbytes(dst); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + ++n_bufs; + + // Primary DSP session from the src0 tensor + ggml_hexagon_session * sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req.flags); + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + if (src1) { + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(dst, &bufs[2]); + } else { + hex_dump_dspbuf(dst, &bufs[1]); + } + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, n_bufs, opt_opsync); + } + + t2 = ggml_time_us(); + + if (src1) { + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u " + "(%f) call-usec %llu\n", + sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + } else { + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u (%f) call-usec " + "%llu\n", + sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + } +} + +static void ggml_hexagon_rope(const struct ggml_tensor * op, uint32_t flags) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * dst = op; + + uint64_t t1 = 0; + uint64_t t2 = 0; + + t1 = ggml_time_us(); + + // Construct HTP message + htp_general_req req; + + memset(&req, 0, sizeof(htp_general_req)); + memcpy(&req.op_params, &op->op_params, sizeof(op->op_params)); + req.flags = flags; + req.op = HTP_OP_ROPE; + + init_htp_tensor(&req.dst, dst); + init_htp_tensor(&req.src0, src0); + init_htp_tensor(&req.src1, src1); + if (src2) { + init_htp_tensor(&req.src2, src2); + } + + // Use opmask to override flags + if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { + req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + dspqueue_buffer bufs[4]; + int n_bufs = 0; + + memset(bufs, 0, sizeof(bufs)); + + // First buffer + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + auto src0_buf = static_cast(src0->buffer->context); + bufs[n_bufs].fd = src0_buf->fd; + bufs[n_bufs].ptr = src0->data; + bufs[n_bufs].offset = (uint8_t *) src0->data - src0_buf->base; + bufs[n_bufs].size = ggml_nbytes(src0); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP; + ++n_bufs; + + // Second buffer + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + auto src1_buf = static_cast(src1->buffer->context); + bufs[n_bufs].fd = src1_buf->fd; + bufs[n_bufs].ptr = src1->data; + bufs[n_bufs].offset = (uint8_t *) src1->data - src1_buf->base; + bufs[n_bufs].size = ggml_nbytes(src1); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + ++n_bufs; + + if (src2) { + // Third buffer + // This is a buffer that the CPU writes and the DSP reads, so we'll + // need to flush CPU caches and invalidate DSP ones. On platforms + // with I/O coherency support the framework will automatically skip + // cache operations where possible. + auto src2_buf = static_cast(src2->buffer->context); + bufs[n_bufs].fd = src2_buf->fd; + bufs[n_bufs].ptr = src2->data; + bufs[n_bufs].offset = (uint8_t *) src2->data - src2_buf->base; + bufs[n_bufs].size = ggml_nbytes(src2); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush CPU + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate DSP + ++n_bufs; + } + + // Final buffer = Output Activations. We'll handle DSP + // Second buffer = Output Activations. We'll handle DSP + // cache maintenance in the response message but need to flush + // CPU caches to ensure any previously written dirty lines are + // written out before writes from the DSP start. + auto dst_buf = static_cast(dst->buffer->context); + bufs[n_bufs].fd = dst_buf->fd; + bufs[n_bufs].ptr = dst->data; + bufs[n_bufs].offset = (uint8_t *) dst->data - dst_buf->base; + bufs[n_bufs].size = ggml_nbytes(dst); + bufs[n_bufs].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER); + ++n_bufs; + + // Primary DSP session from the src0 tensor + ggml_hexagon_session * sess = src0_buf->sess; + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s %s : %s : %s : %s : %s : %s : flags 0x%x\n", sess->name.c_str(), ggml_op_name(op->op), + names, dims, types, strides, buffs, req.flags); + if (opt_verbose > 1) { + hex_dump_dspbuf(src0, &bufs[0]); + if (src1) { + hex_dump_dspbuf(src1, &bufs[1]); + hex_dump_dspbuf(dst, &bufs[2]); + } else { + hex_dump_dspbuf(dst, &bufs[1]); + } + } + } + + if ((opt_opmask & HTP_OPMASK_QUEUE)) { + sess->enqueue(req, bufs, n_bufs, opt_opsync); + } + + t2 = ggml_time_us(); + + if (src2) { + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles " + "%u op-pkts %u (%f) call-usec %llu\n", + sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], src2->name, (uint32_t) src2->ne[0], (uint32_t) src2->ne[1], + (uint32_t) src2->ne[2], (uint32_t) src2->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + } else { + HEX_PROFILE( + "ggml-hex: %s %s %s %u:%u:%u:%u x %s %u:%u:%u:%u -> %s %u:%u:%u:%u : op-usec %u op-cycles %u op-pkts %u " + "(%f) call-usec %llu\n", + sess->name.c_str(), ggml_op_name(op->op), src0->name, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], (uint32_t) src0->ne[3], src1->name, (uint32_t) src1->ne[0], (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], dst->name, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, + (float) sess->prof_cycles / sess->prof_pkts, (unsigned long long) t2 - t1); + } +} + +static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { + auto sess = static_cast(backend->context); + return sess->name.c_str(); +} + +static void ggml_backend_hexagon_free(ggml_backend_t backend) { + // we just need to delete the backend here + // the sessions are allocated & freed as part of the registry + delete backend; +} + +static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { + return (op0 && op0->src[1] == op1->src[1]); +} + +// scan the graph and figure out last compute op index +static inline int last_compute_op(ggml_cgraph * graph) { + int last; + for (int i = 0; i < graph->n_nodes; ++i) { + ggml_tensor * node = graph->nodes[i]; + + switch (node->op) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_RMS_NORM: + case GGML_OP_GLU: + case GGML_OP_ADD_ID: + last = i; + break; + + default: + break; + } + } + + return last; +} + +static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { + auto sess = static_cast(backend->context); + + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); + + const int last = last_compute_op(graph); + + const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer + + for (int i = 0; i < graph->n_nodes; ++i) { + ggml_tensor * node = graph->nodes[i]; + + uint32_t flags = 0; + + // skip quantizer if src1 is reused + if (op_reuse_src1(node, prev_quant_op)) { + flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + } + + // ask for early notification for the last Op + if (i == last) { + flags |= HTP_OPFLAGS_EARLY_WAKEUP; + } + + switch (node->op) { + case GGML_OP_MUL_MAT: + ggml_hexagon_mul_mat(node, flags); + prev_quant_op = node; + break; + case GGML_OP_MUL_MAT_ID: + ggml_hexagon_mul_mat_id(node, flags); + prev_quant_op = node; + break; + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + ggml_hexagon_binary(node, flags); + break; + case GGML_OP_ADD_ID: + ggml_hexagon_add_id(node, flags); + break; + case GGML_OP_RMS_NORM: + ggml_hexagon_unary(node, flags); + break; + case GGML_OP_UNARY: + if (ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) { + ggml_hexagon_unary(node, flags); + } + break; + case GGML_OP_GLU: + if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || + (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { + ggml_hexagon_unary(node, flags); + } + break; + case GGML_OP_SOFT_MAX: + ggml_hexagon_unary(node, flags); + break; + + case GGML_OP_ROPE: + ggml_hexagon_rope(node, flags); + break; + + // non-compute ops + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + break; + + default: + GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); + } + } + + // Wait until all pending ops complete + sess->flush(); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { + auto sess = static_cast(backend->context); + + HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); + + // Wait until all pending ops complete + sess->flush(); +} + +struct node_info { + ggml_tensor * node; + + std::vector fused; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + const ggml_tensor * src0() const { + return node->src[0]; + } + + const ggml_tensor * src1() const { + return node->src[1]; + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } + + bool stackable() const { + switch (this->op()) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return ggml_is_quantized(this->src0()->type); + default: + return false; + } + } + + bool same_input(const node_info& n) const { + return n.src1() == this->src1(); + } +}; + +static std::vector ggml_hexagon_graph_optimize_reorder(const std::vector & nodes) { + const int n = nodes.size(); + + std::vector res; + res.reserve(n); + + std::vector used(n, false); + + // The main goal here is to stack the MUL_MAT ops with the same src1 input. + // This allows use to reuse dynamically quantized src1 in VTCM. + + // TODO: the current version might do incorrect reodering in cases where quantized src0 + // input is an output of another Op. + + for (int i0 = 0; i0 < n; i0++) { + if (used[i0]) { + continue; + } + + res.push_back(i0); + + const auto & node0 = nodes[i0]; + + if (!node0.stackable()) { + continue; + } + + // that many nodes forward to search for stackable nodes that can reuse VTCM + constexpr int N_FORWARD = 8; + + for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { + if (used[i1]) { + continue; + } + + const auto & node1 = nodes[i1]; + + if (node1.stackable() && node1.same_input(node0)) { + res.push_back(i1); + used[i1] = true; + } + } + } + + return res; +} + +static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgraph * gf) { + const int n = gf->n_nodes; + + constexpr int MAX_FUSE = 16; + + enum ggml_op ops[MAX_FUSE]; + + std::vector nodes; + nodes.reserve(gf->n_nodes); + + // fuse nodes: + // we don't want to make reorders that break fusing, so we first pack all fusable tensors + // and perform the reorder over the fused nodes. after the reorder is done, we unfuse + for (int i = 0; i < n; i++) { + node_info node = { + /*.node =*/ gf->nodes[i], + /*.fused =*/ {}, + }; + + // fuse only ops that start with these operations + // can be expanded when needed + if (node.op() == GGML_OP_ADD || + node.op() == GGML_OP_NORM || + node.op() == GGML_OP_RMS_NORM) { + ops[0] = node.op(); + + int f = i + 1; + while (f < n && f < i + MAX_FUSE) { + // conservatively allow fusing only these ops + // can be expanded when needed + if (gf->nodes[f]->op != GGML_OP_ADD && + gf->nodes[f]->op != GGML_OP_MUL && + gf->nodes[f]->op != GGML_OP_NORM && + gf->nodes[f]->op != GGML_OP_RMS_NORM) { + break; + } + ops[f - i] = gf->nodes[f]->op; + f++; + } + + f -= i; + for (; f > 1; f--) { + if (ggml_can_fuse(gf, i, ops, f)) { + break; + } + } + + // add the fused tensors into the node info so we can unfuse them later + for (int k = 1; k < f; k++) { + ++i; + + // the .dst() becomes the last fused tensor + node.add_fused(gf->nodes[i]); + } + } + + nodes.push_back(std::move(node)); + } + + const auto order = ggml_hexagon_graph_optimize_reorder(nodes); + + // unfuse + { + int j = 0; + for (const auto i : order) { + const auto & node = nodes[i]; + + gf->nodes[j++] = node.node; + + for (auto * fused : node.fused) { + gf->nodes[j++] = fused; + } + } + } +} + +static struct ggml_backend_i hexagon_backend_i = { + /* .get_name = */ ggml_backend_hexagon_name, + /* .free = */ ggml_backend_hexagon_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ ggml_backend_hexagon_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_hexagon_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_hexagon_graph_optimize, +}; + +static ggml_guid_t ggml_backend_hexagon_guid() { + static ggml_guid guid = { 0x7b, 0x57, 0xdc, 0xaf, 0xde, 0x12, 0x1d, 0x49, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11 }; + return &guid; +} + +bool ggml_backend_is_hexagon(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_hexagon_name; +} + +// device interface + +static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, const char * params) { + auto sess = static_cast(dev->context); + + return new ggml_backend{ + /* .guid = */ ggml_backend_hexagon_guid(), + /* .interface = */ hexagon_backend_i, + /* .device = */ dev, + /* .context = */ sess, + }; + + GGML_UNUSED(params); +} + +static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) { + auto sess = static_cast(dev->context); + return sess->name.c_str(); + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev_t dev) { + return "Hexagon"; + GGML_UNUSED(dev); +} + +static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // ~2GB per session for now + *free = 2ULL * 1024 * 1024 * 1024; + *total = *free; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_hexagon_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_hexagon_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_hexagon_device_get_name(dev); + props->description = ggml_backend_hexagon_device_get_description(dev); + props->type = ggml_backend_hexagon_device_get_type(dev); + ggml_backend_hexagon_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ true, + /* .host_buffer = */ (bool) opt_hostbuf, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_buffer_type(ggml_backend_dev_t dev) { + auto sess = static_cast(dev->context); + return &sess->buffer_type; +} + +static ggml_backend_buffer_type_t ggml_backend_hexagon_device_get_repack_buffer_type(ggml_backend_dev_t dev) { + auto sess = static_cast(dev->context); + return &sess->repack_buffer_type; +} + +static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + auto sess = static_cast(dev->context); + + bool supp = false; + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + supp = true; + break; + + case GGML_OP_MUL_MAT: + supp = ggml_hexagon_supported_mul_mat(sess, op); + break; + + case GGML_OP_MUL_MAT_ID: + supp = ggml_hexagon_supported_mul_mat_id(sess, op); + break; + + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + supp = ggml_hexagon_supported_binary(sess, op); + break; + + case GGML_OP_ADD_ID: + supp = ggml_hexagon_supported_add_id(sess, op); + break; + + case GGML_OP_RMS_NORM: + supp = ggml_hexagon_supported_unary(sess, op); + break; + + case GGML_OP_SOFT_MAX: + supp = ggml_hexagon_supported_softmax(sess, op); + break; + + case GGML_OP_UNARY: + if (ggml_get_unary_op(op) == GGML_UNARY_OP_SILU) { + supp = ggml_hexagon_supported_activations(sess, op); + } + break; + + case GGML_OP_GLU: + if ((ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU) /* || (ggml_get_glu_op(op) == GGML_GLU_OP_SWIGLU_OAI) */) { + supp = ggml_hexagon_supported_activations(sess, op); + } + break; + + case GGML_OP_ROPE: + supp = ggml_hexagon_supported_rope(sess, op); + break; + + default: + break; + } + + if (opt_verbose) { + char dims[64 * GGML_MAX_SRC]; + char strides[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + hex_format_op_dims(dims, op); + hex_format_op_strides(strides, op); + hex_format_op_types(types, op); + hex_format_op_buffs(buffs, op); + hex_format_op_names(names, op); + + HEX_VERBOSE("ggml-hex: %s device-supports-op %s : %s : %s : %s : %s : %s : (%d)\n", sess->name.c_str(), + ggml_op_name(op->op), names, dims, types, strides, buffs, (int) supp); + } + + return supp; + + GGML_UNUSED(dev); +} + +static bool ggml_backend_hexagon_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_alignment != ggml_backend_hexagon_buffer_type_get_alignment) { + return false; + } + + auto s0 = static_cast(dev->context); + auto s1 = static_cast(buft->context)->sess; + + // Need session/domain-id for buffers to be compatible + bool supp = (s0->session_id == s1->session_id); + + HEX_VERBOSE("ggml-hex: %s device-supports-buft %s (%d)\n", s0->name.c_str(), s1->name.c_str(), (int) supp); + + return supp; +} + +static ggml_backend_buffer_type_t * ggml_backend_hexagon_device_get_extra_buffers_type(ggml_backend_dev_t dev) { + auto s0 = static_cast(dev->context); + HEX_VERBOSE("ggml-hex: device-get-extra-buft : %s \n", s0->name.c_str()); + + static ggml_backend_buffer_type_t bufts[2]; + bufts[0] = ggml_backend_hexagon_device_get_repack_buffer_type(dev); + bufts[1] = NULL; + return bufts; +} + +static const struct ggml_backend_device_i ggml_backend_hexagon_device_i = { + /* .get_name = */ ggml_backend_hexagon_device_get_name, + /* .get_description = */ ggml_backend_hexagon_device_get_description, + /* .get_memory = */ ggml_backend_hexagon_device_get_memory, + /* .get_type = */ ggml_backend_hexagon_device_get_type, + /* .get_props = */ ggml_backend_hexagon_device_get_props, + /* .init_backend = */ ggml_backend_hexagon_device_init, + /* .get_buffer_type = */ ggml_backend_hexagon_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, // ggml_backend_hexagon_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, // ggml_backend_hexagon_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_hexagon_device_supports_op, + /* .supports_buft = */ ggml_backend_hexagon_device_supports_buft, + /* .offload_op = */ NULL, // ggml_backend_hexagon_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +//** backend registry + +#define GGML_HEXAGON_MAX_SESSIONS 16 + +struct ggml_hexagon_registry { + ggml_hexagon_registry(ggml_backend_reg_t reg); + ~ggml_hexagon_registry(); + + ggml_backend_device devices[GGML_HEXAGON_MAX_SESSIONS]; +}; + +ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { + GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); + + if (!opt_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } + + GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); + + // Create devices / sessions + for (size_t i = 0; i < opt_ndev; i++) { + devices[i].iface = ggml_backend_hexagon_device_i; + devices[i].reg = reg; + try { + devices[i].context = new ggml_hexagon_session(i, &devices[i]); + } catch (std::exception const &exc) { + GGML_LOG_ERROR("ggml-hex: failed to create device/session %zu\n", i); + devices[i].context = nullptr; + } + } +} + +ggml_hexagon_registry::~ggml_hexagon_registry() { + GGML_LOG_INFO("ggml-hex: releasing registry\n"); + + // Release devices / sessions + for (size_t i = 0; i < opt_ndev; i++) { + auto sess = static_cast(devices[i].context); + delete sess; + } +} + +static const char * ggml_backend_hexagon_reg_get_name(ggml_backend_reg_t reg) { + return "HTP"; + GGML_UNUSED(reg); +} + +static size_t ggml_backend_hexagon_reg_get_device_count(ggml_backend_reg_t reg) { + return opt_ndev; + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t reg, size_t index) { + auto hreg = static_cast(reg->context); + + if (index >= opt_ndev || !hreg->devices[index].context) { + return nullptr; + } + + return &hreg->devices[index]; +} + +static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { + ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type; + return (void *) fct; + } + + return NULL; +} + +static void ggml_hexagon_init(ggml_backend_reg * reg) { + // Basic sanity checks to make sure definitions match + static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, + "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, + "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, + "please update hexagon_type to match ggml_type"); + + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_profile = getenv("GGML_HEXAGON_PROFILE") != nullptr; + opt_etm = getenv("GGML_HEXAGON_ETM") != nullptr; + opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr; + + const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); + if (str_opmask != nullptr) { + opt_opmask = strtoul(str_opmask, NULL, 0); + } + opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr; + + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + if (str_ndev) { + opt_ndev = strtoul(str_ndev, NULL, 0); + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; + } + } + + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + if (str_nhvx) { + opt_nhvx = strtoul(str_nhvx, NULL, 0); + } + + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + if (str_arch) { + if (str_arch[0] == 'v') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } + + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; + + reg->context = new ggml_hexagon_registry(reg); + + HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), + sizeof(struct htp_general_rsp)); +} + +static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = { + /* .get_name = */ ggml_backend_hexagon_reg_get_name, + /* .get_device_count = */ ggml_backend_hexagon_reg_get_device_count, + /* .get_device = */ ggml_backend_hexagon_reg_get_device, + /* .get_proc_address = */ ggml_backend_hexagon_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_hexagon_reg(void) { + static bool initialized = false; + + static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_hexagon_reg_i, + /* .context = */ NULL }; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_hexagon_init(®); + } + + initialized = true; + } + + return ® +} + +GGML_BACKEND_DL_IMPL(ggml_backend_hexagon_reg) diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c new file mode 100644 index 0000000000..e8a035af8c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-utils.c @@ -0,0 +1,448 @@ + +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wsign-compare" + +#define GGML_COMMON_IMPL_C +#include "ggml-backend-impl.h" +#include "ggml-common.h" +#include "ggml-hexagon.h" +#include "ggml-impl.h" + +#include "htp-utils.h" + +#include +#include +#include +#include +#include +#include +#include + +domain * get_domain(int domain_id) { + int i = 0; + int size = sizeof(supported_domains) / sizeof(domain); + + for (i = 0; i < size; i++) { + if (supported_domains[i].id == domain_id) { + return &supported_domains[i]; + } + } + + return NULL; +} + +bool is_valid_domain_id(int domain_id, int compute_only) { + int i = 0; + int size = sizeof(supported_domains) / sizeof(domain); + + if (compute_only) { + return is_CDSP(domain_id); + } + + for (i = 0; i < size; i++) { + if (supported_domains[i].id == domain_id) { + return true; + } + } + + return false; +} + +int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) { + int nErr = AEE_SUCCESS; + int ss_info = 0; + if (domain_type != NULL) { + if (strcmp(domain_type, "LPASS") == 0) { + ss_info = FASTRPC_LPASS; + } else if (strcmp(domain_type, "HPASS") == 0) { + ss_info = FASTRPC_HPASS; + } else { + ss_info = FASTRPC_NSP; + } + } + system_req_payload req = { 0 }; + req.id = FASTRPC_GET_DOMAINS; + req.sys.domains = NULL; + fastrpc_domain * domain = NULL; + if (ss_info != 0) { + req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info); + } else { + req.sys.flags = 0; + } +#ifdef _WIN32 + nErr = AEE_EUNSUPPORTED; + goto bail; +#endif + if (remote_system_request) { + nErr = remote_system_request(&req); + if (nErr != AEE_SUCCESS) { + GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); + goto bail; + } + // Allocate memory for domain-info array + req.sys.max_domains = req.sys.num_domains; + if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) { + nErr = AEE_ENOMEMORY; + GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains"); + goto bail; + } + + nErr = remote_system_request(&req); + if (nErr != AEE_SUCCESS) { + GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); + goto bail; + } + + for (int i = 0; i < req.sys.num_domains; i++) { + // Verify that only requested type domains were returned + domain = &req.sys.domains[i]; + if (domain->type != ss_info && domain_type != NULL) { + nErr = -1; + GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n"); + goto bail; + } + } + *domains_info = req.sys.domains; + *num_domains = req.sys.num_domains; + } else { + nErr = AEE_EUNSUPPORTED; + goto bail; + } +bail: + if (nErr && !req.sys.domains) { + free(req.sys.domains); + } + return nErr; +} + +int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) { + int err = 0; + remote_rpc_effective_domain_id_t sess = { 0 }; + + sess.domain_name = domain_name; + sess.domain_name_len = strlen(domain_name); + sess.session_id = session_id; + + err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess)); + if (err) { + GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name, + session_id); + return err; + } + + *effec_domain_id = sess.effective_domain_id; + return err; +} + +int get_dsp_support(int * domain) { + int nErr = AEE_SUCCESS; + *domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID + + if (remote_handle_control) { + struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 }; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + goto bail; + } + + if (dsp_capability_domain.capability == 0) { + dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support. + dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT; + dsp_capability_domain.capability = 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, + sizeof(struct remote_dsp_capability)); + if (dsp_capability_domain.capability) { + *domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID + } + } + + if (nErr != AEE_SUCCESS) { + GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return nErr; +} + +int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) { + int nErr = AEE_SUCCESS; + *capability = 0; + + if (attr == VTCM_PAGE || attr == VTCM_COUNT) { + } else { + nErr = AEE_EBADPARM; + GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n"); + goto bail; + } + if (remote_handle_control) { + if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) { + /* + * Query the DSP for VTCM information + * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0 + */ + struct remote_dsp_capability dsp_capability_vtcm_dsp; + dsp_capability_vtcm_dsp.domain = (uint32_t) domain; + dsp_capability_vtcm_dsp.attribute_ID = attr; + dsp_capability_vtcm_dsp.capability = (uint32_t) 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp, + sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + GGML_LOG_ERROR("Running the usecase without checking the capability\n"); + nErr = AEE_SUCCESS; + goto bail; + } else if (nErr == AEE_SUCCESS) { + *capability = dsp_capability_vtcm_dsp.capability; + } else { + GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTED; + GGML_LOG_ERROR("Unsupported domain %d\n", domain); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return nErr; +} + +bool is_unsignedpd_supported(int domain_id) { + int nErr = AEE_SUCCESS; + if (remote_handle_control) { + struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 }; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n"); + return false; + } + if (nErr) { + GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr); + return false; + } + if (dsp_capability_domain.capability == 1) { + return true; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n"); + return false; + } + return false; +} + +bool get_unsignedpd_support(void) { + return is_unsignedpd_supported(CDSP_DOMAIN_ID); +} + +bool is_async_fastrpc_supported(int domain) { + int nErr = AEE_SUCCESS; + if (remote_handle_control) { + if (domain == CDSP_DOMAIN_ID) { + /* + * Query the DSP for ASYNC_FASTRPC_SUPPORT information + * Async fastrpc is supported only on CDSP + */ + struct remote_dsp_capability dsp_capability_async_support; + dsp_capability_async_support.domain = (uint32_t) domain; + dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT; + dsp_capability_async_support.capability = (uint32_t) 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support, + sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + GGML_LOG_ERROR("Running the usecase without checking the capability\n"); + nErr = AEE_SUCCESS; + goto bail; + } else if (dsp_capability_async_support.capability == 1) { + return true; + } + if (nErr != AEE_SUCCESS) { + GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTED; + GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return false; +} + +bool is_status_notification_supported(int domain) { + int nErr = AEE_SUCCESS; + + if (remote_handle_control) { + /* + * Query the DSP for STATUS_NOTIFICATION_SUPPORT information + * DSP User PD status notification Support + */ + struct remote_dsp_capability dsp_capability_status_notification_support; + dsp_capability_status_notification_support.domain = (uint32_t) domain; + dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT; + dsp_capability_status_notification_support.capability = (uint32_t) 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support, + sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + GGML_LOG_ERROR("Running the usecase without checking the capability\n"); + nErr = AEE_SUCCESS; + goto bail; + } else if (dsp_capability_status_notification_support.capability == 1) { + return true; + } + if (nErr != AEE_SUCCESS) { + GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return false; +} + +int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) { + int nErr = AEE_SUCCESS; + *capability = 0; + + if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) { + nErr = AEE_EBADPARM; + GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n"); + goto bail; + } + if (remote_handle_control) { + if (domain == CDSP_DOMAIN_ID) { + /* + * Query the DSP for HMX SUPPORT information + * HMX is supported on CDSP only + */ + struct remote_dsp_capability dsp_capability_hmx_dsp; + dsp_capability_hmx_dsp.domain = (uint32_t) domain; + dsp_capability_hmx_dsp.attribute_ID = attr; + dsp_capability_hmx_dsp.capability = (uint32_t) 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp, + sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + GGML_LOG_ERROR("Running the usecase without checking the capability\n"); + nErr = AEE_SUCCESS; + goto bail; + } else if (nErr == AEE_SUCCESS) { + *capability = dsp_capability_hmx_dsp.capability; + } else { + GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTED; + GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return nErr; +} + +int get_hex_arch_ver(int domain, int * arch) { + if (!remote_handle_control) { + GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + struct remote_dsp_capability arch_ver; + arch_ver.domain = (uint32_t) domain; + arch_ver.attribute_ID = ARCH_VER; + arch_ver.capability = (uint32_t) 0; + + int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); + if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); + return err; + } + + switch (arch_ver.capability & 0xff) { + case 0x73: + *arch = 73; + return 0; + case 0x75: + *arch = 75; + return 0; + case 0x79: + *arch = 79; + return 0; + case 0x81: + *arch = 81; + return 0; + } + return -1; +} + +int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) { + int nErr = AEE_SUCCESS; + *capability = 0; + + if (remote_handle_control) { + if (domain == CDSP_DOMAIN_ID) { + /* + * Query the DSP for HVX SUPPORT information + * HVX is supported on CDSP only + */ + struct remote_dsp_capability dsp_capability_hvx_dsp; + dsp_capability_hvx_dsp.domain = (uint32_t) domain; + dsp_capability_hvx_dsp.attribute_ID = attr; + dsp_capability_hvx_dsp.capability = (uint32_t) 0; + nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp, + sizeof(struct remote_dsp_capability)); + if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { + GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); + GGML_LOG_ERROR("Running the usecase without checking the capability\n"); + nErr = AEE_SUCCESS; + goto bail; + } else if (nErr == AEE_SUCCESS) { + *capability = dsp_capability_hvx_dsp.capability; + } else { + GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTED; + GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain); + goto bail; + } + } else { + nErr = AEE_EUNSUPPORTEDAPI; + GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); + } + +bail: + return nErr; +} diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h new file mode 100644 index 0000000000..66f9fd373e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-utils.h @@ -0,0 +1,219 @@ +#ifndef HTP_UTILS_H +#define HTP_UTILS_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include + +/* Offset to differentiate HLOS and Hexagon error codes. + Stores the value of AEE_EOFFSET for Hexagon. */ +#ifndef DSP_OFFSET +# define DSP_OFFSET 0x80000400 +#endif + +/* Errno for connection reset by peer. */ +#ifndef ECONNRESET +# ifdef __hexagon__ +# define ECONNRESET 104 +# endif +#endif + +/* Abstraction of different OS specific sleep APIs. + SLEEP accepts input in seconds. */ +#ifndef SLEEP +# ifdef __hexagon__ +# define SLEEP(x) \ + { /* Do nothing for simulator. */ \ + } +# else +# ifdef _WINDOWS +# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ +# else +# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ +# endif +# endif +#endif + +/* Include windows specific header files. */ +#ifdef _WINDOWS +# include +# include +# define _CRT_SECURE_NO_WARNINGS 1 +# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 +/* Including this file for custom implementation of getopt function. */ +# include "getopt_custom.h" +#endif + +/* Includes and defines for all HLOS except windows */ +#if !defined(__hexagon__) && !defined(_WINDOWS) +# include "unistd.h" + +# include +#endif + +/* Includes and defines for Hexagon and all HLOS except Windows. */ +#if !defined(_WINDOWS) +/* Weak reference to remote symbol for compilation. */ +# pragma weak remote_session_control +# pragma weak remote_handle_control +# pragma weak remote_handle64_control +# pragma weak fastrpc_mmap +# pragma weak fastrpc_munmap +#endif + +#if !defined(_WINDOWS) +# pragma weak remote_system_request +#endif +/** + * Wrapper for FastRPC Capability API: query DSP support. + * + * @param[out] domain pointer to supported domain. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + */ +int get_dsp_support(int * domain); + +/** + * Wrapper for FastRPC Capability API: query VTCM information. + * + * @param[in] domain value of domain in the queried. + * @param[out] capability capability value of the attribute queried. + * @param[in] attr value of the attribute to the queried. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + */ +int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr); + +/** + * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain. + * + * @return true if unsigned pd is supported. + * false if unsigned pd is not supported, capability query failed. + */ + +bool get_unsignedpd_support(void); + +/** + * Wrapper for FastRPC Capability API: query unsigned pd support. + * + * @param[in] domain value of domain in the queried. + * @return true if unsigned pd is supported. + * false if unsigned pd is not supported, capability query failed. + */ + +bool is_unsignedpd_supported(int domain_id); + +/** + * is_valid_domain_id API: query a domain id is valid. + * + * @param[in] domain value of domain in the queried. + * @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled. + * @return true if value of domain is valid. + * false if value of domain is not valid. + */ + +bool is_valid_domain_id(int domain_id, int compute_only); + +/** + * get_domain API: get domain struct from domain value. + * + * @param[in] domain value of a domain + * @return Returns domain struct of the domain if it is supported or else + * returns NULL. + * + */ + +domain * get_domain(int domain_id); + +/** + * get_domains_info API: get information for all the domains available on the device + * + * @param[in] domain_type pointer to domain type + * @param[in] num_domains pointer to number of domains + * @param[in] domains_info pointer to save discovered domains information. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application. + * + */ + +int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info); + +/** + * get_effective_domain_id API: get effective domain id for given session id + * + * @param[in] domain_name pointer to domain name + * @param[in] session_id + * @param[in] effec_domain_id pointer to save obtained effective domain id. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ + +int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id); + +/** + * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not + * + * @param[in] domain_id value of a domain + * @return Returns true or false stating support of Async FastRPC + * + */ + +bool is_async_fastrpc_supported(int domain_id); + +/** + * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information + * + * @param[in] domain_id value of a domain + * @return Returns true or false stating status notification support information + * + */ +bool is_status_notification_supported(int domain_id); + +/** + * get_hmx_support_info API: query the DSP for HMX SUPPORT information + * + * @param[in] domain_id value of a domain + * @param[out] capability capability value of the attribute queried. + * @param[in] attr value of the attribute to the queried. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr); + +/** + * get_hex_arch_ver API: query the Hexagon processor architecture version information + * + * @param[in] domain_id value of a domain + * @param[out] Arch version (73, 75, ...) + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +int get_hex_arch_ver(int domain, int * arch); + +/** + * get_hvx_support_info API: query the DSP for HVX SUPPORT information + * + * @param[in] domain_id value of a domain + * @param[out] capability capability value of the attribute queried. + * @param[in] attr value of the attribute to the queried. + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr); + +#ifdef __cplusplus +} +#endif + +#endif //DSP_CAPABILITIES_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt new file mode 100644 index 0000000000..22e3fea11d --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -0,0 +1,40 @@ +cmake_minimum_required(VERSION 3.22.2) +project(ggml-htp C CXX ASM) + +include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) + +include_directories( + ${HEXAGON_SDK_ROOT}/incs + ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../.. + ${CMAKE_CURRENT_SOURCE_DIR}/.. + ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR}) + +set(HTP_LIB ggml-htp-${DSP_VERSION}) + +add_library(${HTP_LIB} SHARED + main.c + htp_iface_skel.c + worker-pool.c + htp-dma.c + hvx-sigmoid.c + hvx-inverse.c + hvx-exp.c + hvx-utils.c + matmul-ops.c + binary-ops.c + unary-ops.c + softmax-ops.c + act-ops.c + rope-ops.c +) + +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1>) + +build_idl(htp_iface.idl ${HTP_LIB}) + +set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) + +install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c new file mode 100644 index 0000000000..16044975d9 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -0,0 +1,448 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define htp_act_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +#define htp_act_preamble2 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + is_aligned = 0; + FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + bool src1_valid = src1->ne[0]; + if (!src1_valid) { + data_src1 = data_src0; + src1_row_size = src0_row_size; + } + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + + const int32_t swapped = op_params[1]; + + const int nc = (src1_valid) ? ne0 : ne0 / 2; + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); + const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); + float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + } + + if (!src1_valid) { + src0 += swapped ? nc : 0; + src1 += swapped ? 0 : nc; + } + + if (1 == opt_path) { + hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc); + hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1, + (uint8_t *) dst, nc); + } else { + hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true); + hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc); + hvx_inverse_f32(src1_spad_data, src0_spad_data, nc); + + hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc); + hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_act_preamble3; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = nb01; + const size_t src1_row_size = nb11; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n"); + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + bool src1_valid = src1->ne[0]; + if (!src1_valid) { + data_src1 = data_src0; + } + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + + const int32_t swapped = op_params[1]; + const float alpha = ((const float *) (op_params))[2]; + const float limit = ((const float *) (op_params))[3]; + + const int nc = (src1_valid) ? ne0 : ne0 / 2; + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); + const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); + float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + } + + if (!src1) { + src0 += swapped ? nc : 0; + src1 += swapped ? 0 : nc; + } + + // x (src0_spad_data) = std::min(src0_p[k], limit); + hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc); + // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); + hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc); + // y (src1_spad_data) = y1 + 1.f + hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc); + // x1 (dst_spad_data) = alpha * (x) + hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc); + // x2 (dst_spad_data) = expf(-x1) + hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true); + // x3 (dst_spad_data) = x2 + 1.f + hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc); + // x4 (dst_spad_data) = 1 / x3 + hvx_inverse_f32(dst_spad_data, dst_spad_data, nc); + // out_glu(dst_spad_data) = x * x4 + hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc); + // out = out_glu * (y + 1.f); + hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc); + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_act_preamble2; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + int is_aligned = 1; + int opt_path = 0; + if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + is_aligned = 0; + FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); + float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + } + + if (1 == opt_path) { + hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0); + hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); + } else { + hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true); + hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0); + hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0); + + hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02, + ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread); +} + +static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); +} + +static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); +} + +static int execute_op_activations_fp32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { + FARF(ERROR, "Non-contiguous tensors are not supported at this time \n"); + return HTP_STATUS_NO_SUPPORT; + } + + worker_callback_t act_op_func; + const char * op_type = NULL; + + switch (octx->op) { + case HTP_OP_UNARY_SILU: + act_op_func = unary_silu_fp32; + op_type = "silu-f32"; + break; + + case HTP_OP_GLU_SWIGLU: + act_op_func = glu_swiglu_fp32; + op_type = "swiglu-f32"; + break; + + case HTP_OP_GLU_SWIGLU_OAI: + act_op_func = glu_swiglu_oai_fp32; + op_type = "swiglu-oai-f32"; + break; + + default: + FARF(ERROR, "Unsupported activations Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src1->ne[0] ? src1->nb[1] : src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + octx->dst_spad.size = htp_round_up(dst_row_size, 128) * octx->n_threads; + octx->src0_spad.size = htp_round_up(src0_row_size, 128) * octx->n_threads; + octx->src1_spad.size = htp_round_up(src1_row_size, 128) * octx->n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (src1->ne[0]) { + FARF(HIGH, + "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } else { + FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + } + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "act-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + } + + return err; +} + +int op_activations(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_activations_fp32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c new file mode 100644 index 0000000000..92c0109d28 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -0,0 +1,344 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0, + const uint8_t * src1, + uint8_t * data_dst, + const int num_elems); + +static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; +static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt }; + +#define htp_binary_preamble \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static void binary_job_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + uint8_t * spad_data, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + enum htp_op op) { + htp_binary_preamble; + + const size_t src0_row_size = nb01; + const size_t src1_row_size = nb11; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || + (0 == htp_is_aligned((void *) dst->data, VLEN))) { + FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + is_aligned = 0; + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op]; + + uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); + + const uint32_t nr0 = ne00 / ne10; + + const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); + uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); + + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + const uint8_t * restrict src1_ptr = NULL; + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + src1_ptr = data_src1 + (ir % src1_nrows) * src1_row_size; + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); + if (src1_row_size == src0_row_size) { + htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size); + } + } + + if (nr0 > 1) { + if ((1 == is_aligned) && (nr0 == ne00)) { + hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0); + } else { + for (uint32_t r = 0; r < nr0; r++) { + memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); + } + } + func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00); + } else { + func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + } + + src0_ptr += src0_row_size; + dst_ptr += dst_row_size; + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void binary_add_id_job_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + const struct htp_tensor * src2, + struct htp_tensor * dst, + uint8_t * spad_data, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + hvx_elemwise_f32_func func_HVX) { + htp_binary_preamble; + + const size_t src0_row_size = nb01; + const size_t src1_row_size = nb11; + const size_t dst_row_size = nb1; + + const uint32_t ne02_ne01 = ne02 * ne01; + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || + (0 == htp_is_aligned((void *) dst->data, VLEN))) { + FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n"); + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { + // src0 indices + const uint32_t i03 = ir / ne02_ne01; + const uint32_t i02 = (ir - i03 * ne02_ne01) / ne01; + const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + + // src1 indices + const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]); + assert(i11 >= 0 && i11 < ne11); + + float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1); + const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01); + const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); + + if (ir + 1 < src0_end_row) { + htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); + if (src1_row_size == src0_row_size) { + htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size); + } + } + + const uint32_t nr0 = ne00 / ne10; + if (nr0 > 1) { + for (uint32_t r = 0; r < nr0; r++) { + memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); + } + func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00); + } else { + func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], + src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + + switch (octx->op) { + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + binary_job_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->src1_spad.data, n, i, + octx->src0_nrows_per_thread, octx->op); + break; + + case HTP_OP_ADD_ID: + binary_add_id_job_f32_per_thread(&octx->src0, &octx->src1, &octx->src2, &octx->dst, octx->src0_spad.data, n, + i, octx->src0_nrows_per_thread, hvx_add_f32); + break; + + default: + FARF(ERROR, "Unknown Binary Op %u", octx->op); + break; + } +} + +static int execute_op_binary_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t binary_op_func; + const char * op_type = NULL; + + switch (octx->op) { + case HTP_OP_MUL: + binary_op_func = binary_job_dispatcher_f32; + op_type = "mul-f32"; + break; + + case HTP_OP_ADD: + binary_op_func = binary_job_dispatcher_f32; + op_type = "add-f32"; + break; + + case HTP_OP_SUB: + binary_op_func = binary_job_dispatcher_f32; + op_type = "sub-f32"; + break; + + case HTP_OP_ADD_ID: + binary_op_func = binary_job_dispatcher_f32; + op_type = "add-id-f32"; + break; + + default: + FARF(ERROR, "Unsupported binary-Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src1->nb[1]; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + FARF(HIGH, + "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + octx->ctx->vtcm_size, spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); + } + + return err; +} + +int op_binary(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_binary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake new file mode 100644 index 0000000000..7fa236e328 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -0,0 +1,157 @@ +if (HEXAGON_TOOLCHAIN_INCLUDED) + return() +endif() +set(HEXAGON_TOOLCHAIN_INCLUDED true) + +#Cross Compiling for Hexagon +set(HEXAGON TRUE) +set(CMAKE_SYSTEM_NAME QURT) +set(CMAKE_SYSTEM_PROCESSOR Hexagon) +set(CMAKE_SYSTEM_VERSION "1") #${HEXAGON_PLATFORM_LEVEL}) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +set(CUSTOM_RUNELF_PATH "") + +#To fix backward compatibility with EAI addon. +if (NOT HEXAGON_SDK_ROOT) + set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) +endif() + +if (NOT HEXAGON_TOOLS_ROOT) + if (DEFINED ENV{HEXAGON_TOOLS_ROOT}) + set(HEXAGON_TOOLS_ROOT $ENV{HEXAGON_TOOLS_ROOT}) + endif() + if(NOT HEXAGON_TOOLS_ROOT) + set(HEXAGON_TOOLS_ROOT $ENV{DEFAULT_HEXAGON_TOOLS_ROOT}) + endif() +endif() + +file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) +file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) + +#Get the Binary extension of the Hexagon Toolchain +if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) + set(HEXAGON_TOOLCHAIN_SUFFIX .exe) +endif() +message(DEBUG "CMAKE_HOST_SYSTEM_NAME:${CMAKE_HOST_SYSTEM_NAME}") + +include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_arch.cmake) + +set(HEXAGON_TOOLCHAIN ${HEXAGON_TOOLS_ROOT}) +set(HEXAGON_LIB_DIR "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib") +set(HEXAGON_ISS_DIR ${HEXAGON_TOOLCHAIN}/Tools/lib/iss) + +set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES + HEXAGON_SDK_ROOT + HEXAGON_TOOLS_ROOT +) + +#QURT Related includes and linker flags +set(V_ARCH ${HEXAGON_ARCH}) +set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}") +set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}") + +if( ${TREE} MATCHES PAKMAN ) + set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}") +endif() +message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}") +set(RTOS_DIR ${_QURT_INSTALL_DIR}) +set(QCC_DIR "${HEXAGON_QCC_DIR}/${V_ARCH}/G0") +set(TARGET_DIR "${HEXAGON_LIB_DIR}/${V_ARCH}/G0") + +include_directories( + ${_QURT_INSTALL_DIR}/include + ${_QURT_INSTALL_DIR}/include/qurt + ${_QURT_INSTALL_DIR}/include/posix + ) + +set(QURT_START_LINK_LIBS) +set(QURT_START_LINK_LIBS + "${TARGET_DIR}/init.o" + "${RTOS_DIR}/lib/crt1.o" + "${RTOS_DIR}/lib/debugmon.o" + "${RTOS_DIR}/lib/libqurt.a" + "${TARGET_DIR}/libc.a" + "${TARGET_DIR}/libqcc.a" + "${TARGET_DIR}/libhexagon.a" + "${RTOS_DIR}/lib/libqurtcfs.a" + "${RTOS_DIR}/lib/libtimer_island.a" + "${RTOS_DIR}/lib/libtimer_main.a" + "${RTOS_DIR}/lib/libposix.a" + ) +STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}") + +set(QURT_END_LINK_LIBS + ${TARGET_DIR}/fini.o + ) + +#Non QURT related includes and linker flags + +set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}") + +if (NOT NO_WRAP_MEM_API) + set(WRAP_MALLOC -Wl,--wrap=malloc) + set(WRAP_CALLOC -Wl,--wrap=calloc) + set(WRAP_FREE -Wl,--wrap=free) + set(WRAP_REALLOC -Wl,--wrap=realloc) + set(WRAP_MEMALIGN -Wl,--wrap=memalign) +endif() + +set(PIC_SHARED_LD_FLAGS + -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} + -G0 + -fpic + -Wl,-Bsymbolic + -Wl,-L${TARGET_DIR_NOOS}/G0/pic + -Wl,-L${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/ + -Wl,--no-threads ${WRAP_MALLOC} ${WRAP_CALLOC} ${WRAP_FREE} ${WRAP_REALLOC} ${WRAP_MEMALIGN} + -shared + "-o " + "" + -Wl,--start-group + "" + "" + -Wl,--end-group + -lc + ) +STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}") + +set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}") + +#System include paths +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) +include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs) + +#LLVM toolchain setup +#Compiler paths, options and architecture +set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX}) +set(CMAKE_ASM_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) +set(HEXAGON_LINKER ${CMAKE_C_COMPILER}) +set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon) + +set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") +set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") + +#Compiler Options +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") + +set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") + +set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") + +set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") +set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") +set(CMAKE_ASM_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELWITHDEBINFO}" ) + +#Linker Options +set(CMAKE_C_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}") +set(CMAKE_CXX_CREATE_SHARED_LIBRARY "${HEXAGON_LINKER} ${HEXAGON_PIC_SHARED_LINK_OPTIONS}") diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h new file mode 100644 index 0000000000..5c3d217f1c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -0,0 +1,40 @@ +#ifndef HTP_CTX_H +#define HTP_CTX_H + +#include "htp-dma.h" +#include "worker-pool.h" + +#include +#include +#include +#include + +#define HTP_MAX_NTHREADS 10 + +// FIXME: move these into matmul-ops +#define HTP_SPAD_SRC0_NROWS 16 +#define HTP_SPAD_SRC1_NROWS 16 +#define HTP_SPAD_DST_NROWS 2 + +// Main context for htp DSP backend +struct htp_context { + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; + + int thread_id; + int thread_prio; + + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + + atomic_bool vtcm_valid; + atomic_bool vtcm_inuse; + atomic_bool vtcm_needs_release; + + uint32_t opmask; +}; + +#endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.c b/ggml/src/ggml-hexagon/htp/htp-dma.c new file mode 100644 index 0000000000..10c54b45ee --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp-dma.c @@ -0,0 +1,69 @@ +#include "htp-dma.h" + +#include +#include +#include + +#pragma clang diagnostic ignored "-Wunused-function" + +static inline uint32_t pow2_ceil(uint32_t x) { + if (x <= 1) { + return 1; + } + int p = 2; + x--; + while (x >>= 1) { + p <<= 1; + } + return p; +} + +dma_queue * dma_queue_create(size_t capacity) { + dma_queue * q = (dma_queue *) memalign(32, sizeof(dma_queue)); + if (q == NULL) { + FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__); + return NULL; + } + + capacity = pow2_ceil(capacity); + + memset(q, 0, sizeof(dma_queue)); + q->capacity = capacity; + q->idx_mask = capacity - 1; + + q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + + q->dst = (void **) memalign(4, capacity * sizeof(void *)); + memset(q->dst, 0, capacity * sizeof(void *)); + + q->tail = &q->desc[capacity - 1]; + + if (!q->desc && !q->dst) { + FARF(ERROR, "%s: failed to allocate DMA queue items\n", __FUNCTION__); + return NULL; + } + + FARF(HIGH, "dma-queue: capacity %u\n", capacity); + + return q; +} + +void dma_queue_delete(dma_queue * q) { + if (!q) { + return; + } + free(q->desc); + free(q->dst); + free(q); +} + +void dma_queue_flush(dma_queue * q) { + while (1) { + uint32_t s = dmwait() & 0x3; + if (s == HEXAGON_UDMA_DM0_STATUS_IDLE) { + break; + } + } + q->tail = NULL; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/htp-dma.h new file mode 100644 index 0000000000..4d0d54ce85 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp-dma.h @@ -0,0 +1,119 @@ +#ifndef HTP_DMA_H +#define HTP_DMA_H + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + hexagon_udma_descriptor_type1_t * desc; // descriptor pointers + hexagon_udma_descriptor_type1_t * tail; // tail pointer + void ** dst; // dst pointers + uint32_t push_idx; + uint32_t pop_idx; + uint32_t capacity; + uint32_t idx_mask; +} dma_queue; + +dma_queue * dma_queue_create(size_t capacity); +void dma_queue_delete(dma_queue * q); +void dma_queue_flush(dma_queue * q); + +// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead +// but those do not seem to always compiler properly. +static inline void dmstart(void * next) { + asm volatile(" release(%0):at" : : "r"(next)); + asm volatile(" dmstart(%0)" : : "r"(next)); +} + +static inline void dmlink(void * cur, void * next) { + asm volatile(" release(%0):at" : : "r"(next)); + asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next)); +} + +static inline unsigned int dmpoll(void) { + unsigned int ret = 0; + asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory"); + return ret; +} + +static inline unsigned int dmwait(void) { + unsigned int ret = 0; + asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory"); + return ret; +} + +static inline bool dma_queue_push(dma_queue * q, + void * dst, + const void * src, + size_t dst_row_size, + size_t src_row_size, + size_t nrows) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + return false; + } + + hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; + + desc->next = NULL; + desc->length = 0; + desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; + desc->dstbypass = 1; + desc->srcbypass = 1; + desc->order = 0; + desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; + desc->src = (void *) src; + desc->dst = (void *) dst; + desc->allocation = 0; + desc->padding = 0; + desc->roiwidth = src_row_size; + desc->roiheight = nrows; + desc->srcstride = src_row_size; + desc->dststride = dst_row_size; + desc->srcwidthoffset = 0; + desc->dstwidthoffset = 0; + + q->dst[q->push_idx] = dst; + + dmlink(q->tail, desc); + q->tail = desc; + + // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline uint8_t * dma_queue_pop(dma_queue * q) { + if (q->push_idx == q->pop_idx) { + return NULL; + } + + hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; + + // Wait for desc to complete + while (1) { + dmpoll(); + if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { + break; + } + // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + } + + uint8_t * dst = (uint8_t *) q->dst[q->pop_idx]; + + // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dst; +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HTP_DMA_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h new file mode 100644 index 0000000000..f23d578806 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -0,0 +1,156 @@ +#ifndef HTP_MSG_H +#define HTP_MSG_H + +#include + +// ggml-common.h must be included prio to this header + +// Mask to enable various stages of the Ops. +// Used for debugging and profiling. +enum { + HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) + HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize + HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute +}; + +// Op flags +enum { + HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors) + HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling) + HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification +}; + +enum htp_status { + HTP_STATUS_OK = 1, + HTP_STATUS_INTERNAL_ERR = 2, + HTP_STATUS_NO_SUPPORT = 3, + HTP_STATUS_INVAL_PARAMS = 4, + HTP_STATUS_VTCM_TOO_SMALL = 5, +}; + +// The values must match the ggml_type. +// Duplicated here because we can't include full ggml.h in the htp build. +// We have some static_asserts in the cpp code to ensure things are in sync. +enum htp_data_type { + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_MXFP4 = 39, + HTP_TYPE_COUNT +}; + +// These values are manually translated over to HTP +// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!! +enum htp_op { + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT = 4, + HTP_OP_MUL_MAT_ID = 5, + HTP_OP_RMS_NORM = 6, + HTP_OP_UNARY_SILU = 7, + HTP_OP_GLU_SWIGLU = 8, + HTP_OP_GLU_SWIGLU_OAI = 9, + HTP_OP_SOFTMAX = 10, + HTP_OP_ADD_ID = 11, + HTP_OP_ROPE = 12, + INVALID +}; + +static inline size_t htp_type_block_size(uint32_t t) { + switch (t) { + case HTP_TYPE_F32: + return 1; + case HTP_TYPE_F16: + return 1; + case HTP_TYPE_Q4_0: + return QK4_0; + case HTP_TYPE_Q8_0: + return QK8_0; + case HTP_TYPE_MXFP4: + return QK_MXFP4; + default: + assert(0 && "unsupported HTP data type"); + } + return 0; +} + +static inline size_t htp_type_nbytes(uint32_t t) { + switch (t) { + case HTP_TYPE_F32: + return 4; + case HTP_TYPE_F16: + return 2; + case HTP_TYPE_Q4_0: + return sizeof(block_q4_0); + case HTP_TYPE_Q8_0: + return sizeof(block_q8_0); + case HTP_TYPE_MXFP4: + return sizeof(block_mxfp4); + default: + assert(0 && "unsupported HTP data type"); + } + return 0; +} + +static const char * htp_type_name(uint32_t t) { + switch (t) { + case HTP_TYPE_F32: + return "fp32"; + case HTP_TYPE_F16: + return "fp16"; + case HTP_TYPE_Q4_0: + return "q4_0"; + case HTP_TYPE_Q8_0: + return "q8_0"; + case HTP_TYPE_MXFP4: + return "mxfp4"; + } + return 0; +} + +// Internal types +#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks + +#define HTP_MAX_DIMS 4 + +struct htp_tensor { + uint32_t data; // Buffer offset in the messages, and data pointer on the NSP + uint32_t type; // Data type + uint32_t ne[HTP_MAX_DIMS]; // Number of elements + uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) +}; + +#define HTP_MAX_OP_PARAMS 64 + +struct htp_general_req { + uint32_t op; // GGML/HTP Op + int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; + // Params for the op, e.g. epsilon of RMS norm + uint32_t flags; // Request flags + + struct htp_tensor src0; // Input0 tensor + struct htp_tensor src1; // Input1 tensor + struct htp_tensor src2; // Input2 tensor + struct htp_tensor dst; // Output tensor + + // should be multiple of 64 bytes (cacheline) +}; + +struct htp_general_rsp { + uint32_t op; // GGML/HTP Op + uint32_t status; // HTP_STATUS_... + uint32_t prof_usecs; // Number of usec per request + uint32_t prof_cycles; // Number of cycles per request + uint32_t prof_pkts; // Number of instruction packets per request + uint8_t unused[44]; // Pad to 64 bytes +}; + +#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) +#define HTP_MAX_PACKET_BUFFERS 4 + +#endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h new file mode 100644 index 0000000000..4572319679 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -0,0 +1,53 @@ +#ifndef HTP_OPS_H +#define HTP_OPS_H + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "worker-pool.h" + +#include +#include + +// ggml-common.h must be included prior to this header + +struct htp_spad { + uint8_t * data; + size_t size; + size_t size_per_thread; +}; + +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op op; + int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; + + struct htp_tensor src0; + struct htp_tensor src1; + struct htp_tensor src2; + struct htp_tensor dst; + + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad dst_spad; + + worker_pool_context_t * wpool; // worker pool + uint32_t n_threads; // num threads + + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + uint32_t flags; +}; + +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); +int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); + +#endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl new file mode 100644 index 0000000000..9ebd937e46 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -0,0 +1,16 @@ +// FastRPC IDL interface for GGML HTP + +#ifndef HTP_IDL +#define HTP_IDL + +#include "AEEStdDef.idl" +#include "remote.idl" + +interface htp_iface : remote_handle64 { + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult stop(); + AEEResult enable_etm(); + AEEResult disable_etm(); +}; + +#endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c new file mode 100644 index 0000000000..19f6795083 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.c @@ -0,0 +1,80 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + // assert((0 == unaligned_addr) || (0 == num_elems_whole)); + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector vec_out = Q6_V_vzero(); + + if (0 == unaligned_loop) { + HVX_Vector * p_vec_in1 = (HVX_Vector *) src; + HVX_Vector * p_vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_fp32(neg_vec_in); + } else { + *p_vec_out++ = hvx_vec_exp_fp32(*p_vec_in1++); + } + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(neg_vec_in); + } else { + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32(in); + } + } + } + + if (left_over > 0) { + const float * srcf = (float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); + + vec_out = hvx_vec_exp_fp32(neg_vec_in); + } else { + vec_out = hvx_vec_exp_fp32(in); + } + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); + } +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c new file mode 100644 index 0000000000..4cf588a878 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.c @@ -0,0 +1,60 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + // assert((0 == unaligned_addr) || (0 == num_elems_whole)); + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + if (0 == unaligned_loop) { + HVX_Vector * p_vec_in = (HVX_Vector *) src; + HVX_Vector * p_vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + *p_vec_out++ = hvx_vec_inverse_fp32(*p_vec_in++); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32(in); + } + } + + if (left_over > 0) { + const float * srcf = (float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + HVX_Vector out = hvx_vec_inverse_fp32(in); + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); + } +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c new file mode 100644 index 0000000000..15ac64697c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c @@ -0,0 +1,49 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#if 0 +// Reference algo used in hvx-utils +static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems) +{ + const float c1 = 0.03138777; + const float c2 = 0.276281267; + const float c_log2f = 1.442695022; + + int32_t store_ints[32]; + float store_floats[3][32]; + + for (int i = 0; i < num_elems; i++) + { + float v = src0[i]; + + v *= c_log2f*0.5; + int intPart = (int)v; + float x = (v - intPart); + float xx = x * x; + float v1 = c_log2f + c2 * xx; + float v2 = x + xx * c1 * x; + float v3 = (v2 + v1); + *((int*)&v3) += intPart << 24; + float v4 = v2 - v1; + float v5 = v3 - v4; + float res = v3 / v5; + + dst[i] = res; + } +} +#endif diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c new file mode 100644 index 0000000000..d3599bc9c1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -0,0 +1,947 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hvx-utils.h" + +#define htp_binary_ops_preamble \ + int step_of_4 = num_elems >> 7; \ + int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6; \ + int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5; \ + int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \ + \ + const uint8_t * restrict src0_curr = src0; \ + const uint8_t * restrict src1_curr = src1; \ + uint8_t * restrict dst_curr = dst; + +void hvx_mul_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || + (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; + HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); + HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * src0f = (const float *) src0 + num_elems_whole; + const float * src1f = (const float *) src1 + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in1 = *(HVX_UVector *) src0f; + HVX_Vector in2 = *(HVX_UVector *) src1f; + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +void hvx_mul_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + htp_binary_ops_preamble; + + for (int i = 0; i < step_of_4; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); + + HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); + + src0_curr += 4 * VLEN; + + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b); + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); + + *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); + + HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b); + + src1_curr += 4 * VLEN; + + *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); + + dst_curr += 4 * VLEN; + } + + for (int i = 0; i < step_of_2; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + src0_curr += 2 * VLEN; + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); + + src1_curr += 2 * VLEN; + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + dst_curr += 2 * VLEN; + } + + for (int i = 0; i < step_of_1; i++) { + HVX_Vector va = *(HVX_Vector *) src0_curr; + + src0_curr += VLEN; + + HVX_Vector vb = *(HVX_Vector *) src1_curr; + + src1_curr += VLEN; + + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); + + dst_curr += VLEN; + } + + if (remaining > 0) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); + hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); + } +} + +void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + const uint8_t * restrict src2, + uint8_t * restrict dst, + const int num_elems) { + const uint8_t * restrict src0_curr = src0; + const uint8_t * restrict src1_curr = src1; + const uint8_t * restrict src2_curr = src2; + uint8_t * restrict dst_curr = dst; + + int step_of_2 = num_elems >> 6; + int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5; + int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; + + for (int i = 0; i < step_of_2; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + HVX_Vector v1c = *(HVX_Vector *) src2_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN); + + src0_curr += 2 * VLEN; + + HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c); + + src1_curr += 2 * VLEN; + src2_curr += 2 * VLEN; + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + dst_curr += 2 * VLEN; + } + for (int i = 0; i < step_of_1; i++) { + HVX_Vector va = *(HVX_Vector *) src0_curr; + src0_curr += VLEN; + + HVX_Vector vb = *(HVX_Vector *) src1_curr; + src1_curr += VLEN; + + HVX_Vector vc = *(HVX_Vector *) src2_curr; + src2_curr += VLEN; + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2); + dst_curr += VLEN; + } + if (remaining > 0) { + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr); + hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2)); + } +} + +void hvx_add_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || + (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; + HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); + HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * src0f = (const float *) src0 + num_elems_whole; + const float * src1f = (const float *) src1 + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in1 = *(HVX_UVector *) src0f; + HVX_Vector in2 = *(HVX_UVector *) src1f; + + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +void hvx_add_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + htp_binary_ops_preamble; + + for (int i = 0; i < step_of_4; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); + + HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); + + HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); + + src0_curr += 4 * VLEN; + + HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b); + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); + + *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); + + HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b); + + src1_curr += 4 * VLEN; + + *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); + + dst_curr += 4 * VLEN; + } + for (int i = 0; i < step_of_2; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + src0_curr += 2 * VLEN; + + HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); + + src1_curr += 2 * VLEN; + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + dst_curr += 2 * VLEN; + } + for (int i = 0; i < step_of_1; i++) { + HVX_Vector va = *(HVX_Vector *) src0_curr; + + src0_curr += VLEN; + + HVX_Vector vb = *(HVX_Vector *) src1_curr; + + src1_curr += VLEN; + + HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); + + dst_curr += VLEN; + } + if (remaining > 0) { + HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); + hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); + } +} + +void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector val_vec = hvx_vec_splat_fp32(val); + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, val_vec); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector val_vec = hvx_vec_splat_fp32(val); + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +void hvx_sub_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || + (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; + HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); + HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * src0f = (const float *) src0 + num_elems_whole; + const float * src1f = (const float *) src1 + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in1 = *(HVX_UVector *) src0f; + HVX_Vector in2 = *(HVX_UVector *) src1f; + + HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +void hvx_sub_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems) { + htp_binary_ops_preamble; + + for (int i = 0; i < step_of_4; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); + + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); + + HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); + + src0_curr += 4 * VLEN; + + HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b); + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); + + *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); + + HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b); + + src1_curr += 4 * VLEN; + + *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); + + dst_curr += 4 * VLEN; + } + for (int i = 0; i < step_of_2; i++) { + HVX_Vector v1a = *(HVX_Vector *) src0_curr; + + HVX_Vector v1b = *(HVX_Vector *) src1_curr; + + HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); + + HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); + + src0_curr += 2 * VLEN; + + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); + + src1_curr += 2 * VLEN; + + *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); + + dst_curr += 2 * VLEN; + } + for (int i = 0; i < step_of_1; i++) { + HVX_Vector va = *(HVX_Vector *) src0_curr; + + src0_curr += VLEN; + + HVX_Vector vb = *(HVX_Vector *) src1_curr; + + src1_curr += VLEN; + + HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); + + dst_curr += VLEN; + } + if (remaining > 0) { + HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); + hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); + } +} + +void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector val_vec = hvx_vec_splat_fp32(val); + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + if (0 == htp_is_aligned((void *) src, VLEN)) { + FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n"); + } + + assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole)); + + HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; + + HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000); + HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1); + sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v); + vec_in1++; + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + + HVX_Vector vec_left = *(HVX_UVector *) srcf; + + HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left); + HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32); + + sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp); + } + + HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc); + return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); +} + +float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if (0 == htp_is_aligned((void *) src, VLEN)) { + FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); + HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); + + if (0 == unaligned_loop) { + HVX_Vector * vec_in = (HVX_Vector *) src; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++); + sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + + HVX_Vector vec_left = *(HVX_UVector *) srcf; + HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32); + // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp); + sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp); + } + + HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec); + return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); +} + +void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); + + if (0 == unaligned_loop) { + HVX_Vector * vec_in1 = (HVX_Vector *) src; + HVX_Vector * vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec); + *vec_out++ = Q6_Vsf_equals_Vqf32(v); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); + + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); + } +} + +float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if (0 == htp_is_aligned((void *) src, VLEN)) { + FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n"); + unaligned_addr = 1; + } + + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n"); + } + + HVX_Vector vec_max = hvx_vec_splat_fp32(((const float *) src)[0]); + HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]); + + if (0 == unaligned_loop) { + HVX_Vector * restrict vec_in = (HVX_Vector *) src; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++); + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in); + } + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32); + vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, temp); + } + + HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max); + return hvx_vec_get_fp32(v); +} + +void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); + } + + assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole)); + + const float * src_f = (const float *) src; + + HVX_Vector vec_min = Q6_V_vsplat_R(val); + + HVX_Vector * restrict vec_in = (HVX_Vector *) src; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); + *vec_out++ = Q6_Vsf_equals_Vqf32(vec_min); + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in); + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min)); + } +} + +void hvx_clamp_scalar_f32(const uint8_t * restrict src, + const float limit_left, + const float limit_right, + uint8_t * restrict dst, + const int num_elems) { + size_t left_over = num_elems & (VLEN_FP32 - 1); + size_t num_elems_whole = num_elems - left_over; + + if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { + FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); + } + + assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole)); + + HVX_Vector * restrict vec_in = (HVX_Vector *) src; + HVX_Vector * restrict vec_out = (HVX_Vector *) dst; + + HVX_Vector range_left = hvx_vec_splat_fp32(limit_left); + HVX_Vector range_right = hvx_vec_splat_fp32(limit_right); + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in_vec = *vec_in++; + HVX_Vector temp_v = in_vec; + + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); + + in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v); + + *vec_out++ = Q6_Vsf_equals_Vqf32(in_vec); + } + + if (left_over > 0) { + const float * srcf = (const float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + HVX_Vector temp_v = in; + + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in, range_right); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in); + + in = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); + in = Q6_V_vmux_QVV(pred_cap_left, range_left, temp_v); + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(in)); + } +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h new file mode 100644 index 0000000000..b2ca8e88f4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -0,0 +1,998 @@ +#ifndef HVX_UTILS_H +#define HVX_UTILS_H + +#include "ops-utils.h" + +#include +#include + +#define SIZEOF_FP32 (4) +#define SIZEOF_FP16 (2) +#define VLEN (128) +#define VLEN_FP32 (VLEN / SIZEOF_FP32) +#define VLEN_FP16 (VLEN / SIZEOF_FP16) + +static inline HVX_Vector hvx_vec_splat_fp32(float i) { + union { + float f; + int32_t i; + } fp32 = { .f = i }; + + return Q6_V_vsplat_R(fp32.i); +} + +static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { + // Rotate as needed. + v = Q6_V_vlalign_VVR(v, v, (size_t) addr); + + uint32_t left_off = (size_t) addr & 127; + uint32_t right_off = left_off + n; + + HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr); + HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); + + if (right_off > 128) { + Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v); + // all 1's + qr = Q6_Q_vcmp_eq_VbVb(v, v); + } + + ql_not = Q6_Q_or_QQn(ql_not, qr); + Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v); +} + +static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) { + assert((unsigned long) ptr % 128 == 0); + + HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr); + HVX_VectorPred qr = Q6_Q_vsetq2_R(n); + ql_not = Q6_Q_or_QQn(ql_not, qr); + Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v); +} + +static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all elements + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + + HVX_Vector ctrl = *(HVX_Vector *) repl; + return Q6_V_vdelta_VV(v, ctrl); +} + +// copy n fp16 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_Vector * restrict vsrc = (HVX_Vector *) src; + + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); + } +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_Vector * restrict vsrc = (HVX_Vector *) src; + + assert((unsigned long) src % 128 == 0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); + } +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); + } +} + +// copy n fp32 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_Vector * restrict vsrc = (HVX_Vector *) src; + + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_Vector * restrict vsrc = (HVX_Vector *) src; + + assert((unsigned long) src % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned +static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + + HVX_Vector velem = hvx_vec_splat_fp32(elem); + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + vdst[i] = velem; + } + + if (nloe) { + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem); + } +} + +static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { + union { + HVX_Vector v; + __fp16 d[64]; + } u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + htp_dump_fp16_line(pref, u.d + (16 * i), 16); + } + if (n1) { + htp_dump_fp16_line(pref, u.d + (16 * i), n1); + } +} + +static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) { + hvx_vec_dump_fp16_n(pref, v, 64); +} + +static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + htp_dump_fp32_line(pref, u.d + (16 * i), 16); + } + if (n1) { + htp_dump_fp32_line(pref, u.d + (16 * i), n1); + } +} + +static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], + u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) { + hvx_vec_dump_fp32_n(pref, v, 32); +} + +static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + for (int i = 0; i < 32 / 16; i++) { + htp_dump_int32_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], + u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], + u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); +} + +static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + htp_dump_int8_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + uint8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + htp_dump_uint8_line(pref, u.d + (16 * i), 16); + } +} + +static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { + typedef union { + HVX_Vector v; + int8_t d[128]; + } U; + + U u0 = { .v = v0 }; + U u1 = { .v = v1 }; + + for (int i = 0; i < n; i++) { + if (u0.d[i] != u1.d[i]) { + return false; + } + } + + return true; +} + +static inline float hvx_vec_get_fp32(HVX_Vector v) { + float __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // int32 + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) { + return hvx_vec_int32_reduce_sum_n(in, 32); +} + +static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right + sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) { + return hvx_vec_qf32_reduce_sum_n(in, 32); +} + +static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) { + return hvx_vec_fp32_reduce_sum_n(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp16 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vhf_vmax_VhfVhf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vsf_vmax_VsfVsf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) { + // abs by clearing the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) { + // neg by setting the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); + return Q6_V_vor_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { + // abs by clearing the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { +#if __HTP_ARCH__ > 75 + return Q6_Vsf_vfneg_Vsf(v); +#else + // neg by setting the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x80000000); + return Q6_V_vor_VV(v, mask); +#endif // __HTP_ARCH__ > 75 +} + +// ==================================================== +// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 +// Order:3; continuity: True; Ends forced: True +// Mode: unsigned; Result fractional bits: 14 +// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 +// 32769 -32706 31252 -10589 +// 32590 -30635 22793 -4493 +// 32066 -27505 16481 -2348 +// 31205 -24054 11849 -1306 + +static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { + // input is 0..0xffff representing 0.0 .. 1.0 + HVX_Vector p; + p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); + p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); + return p; // signed result, 14 fractional bits +} + +// Find reciprocal of fp16. +// (1) first, convert to fp32, multiplying by 1.0; this is done to +// handle denormals. Ignoring sign and zero, result should be at +// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) +// (exponent in range [103,143]) +// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly +// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32 +// (4) convert that to fp16 +// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace +// the result with the max value. +static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) { + HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); + HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); + HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); + // is too small to 1/x ? for 'standard' fp16, this would be 0x101 + HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); + + HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 + HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); + HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); + + // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector + HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); + // likewise extract the upper 16 from each, containing the exponents in range 103..142 + HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); + //Get exponent in IEEE 32-bit representation + exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); + + // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane + // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) + // Use poly to transform to 1/x, with 14 fractional bits + // + HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); + + HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros + + // Get mantissa for 16-bit represenation + HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); + + //Compute Reciprocal Exponent + HVX_Vector exp_recip = + Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); + //Convert it for 16-bit representation + exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); + exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); + + //Merge exponent and mantissa for reciprocal + HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); + // map 'small' inputs to standard largest value 0x7bff + recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); + // add sign back + recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); + return recip; +} + +#define IEEE_VSF_EXPLEN (8) +#define IEEE_VSF_EXPBIAS (127) +#define IEEE_VSF_EXPMASK (0xFF) +#define IEEE_VSF_MANTLEN (23) +#define IEEE_VSF_MANTMASK (0x7FFFFF) +#define IEEE_VSF_MIMPMASK (0x800000) + +static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_zero_v = Q6_V_vzero(); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + + HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift + + HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 + + vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + HVX_Vector neg_vout = -vout; + + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + + return (vout); +} + +static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v >>= expval_v; + HVX_Vector neg_addin_v = mask_impl_v >> expval_v; + HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = in_vec; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + + return vout; +} + +static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { + // This looks complicated. + // Ideally should just be Q6_Vh_equals_Vhf(vin) + // but that instruction does not do proper rounding. + + // convert to qf32, multiplying by 1.0 in the process. + HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); + + // 'in-range' values are +/32752. + // add 192K to it, convert to sf + HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); + HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); + HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); + + // for in-range cases, result is {163858... 229360} so the exponent is always 144. + // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. + // Start by <<10 to get the final 'sign' bit in bit 15... + vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10); + vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10); + + // now round down to 16 + return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); +} + +static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) { + HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); + HVX_Vector two_sf = hvx_vec_splat_fp32(2.0); + + // First approximation + HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); + + HVX_Vector r_qf; + + // Refine + r_qf = Q6_Vqf32_vmpy_VsfVsf( + i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf))))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + + return Q6_Vsf_equals_Vqf32(r_qf); +} + +#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 +#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 +#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 +#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 + +static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) { + v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); + + HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v)); + HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); + HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); + v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); + v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); + v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); + + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); + HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); + v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); + v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); + v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); + + HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); + HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); + + HVX_Vector res = hvx_vec_inverse_fp32(v5); + res = Q6_Vqf32_vmpy_VsfVsf(v3, res); + + return Q6_Vsf_equals_Vqf32(res); +} + +#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) +#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) +#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) +#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) +#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) +#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) +#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 +#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 +#define EXP_ONE (0x3f800000) // 1.0 +#define EXP_RANGE_R (0x41a00000) // 20.0 +#define EXP_RANGE_L (0xc1a00000) // -20.0 + +static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) { + HVX_Vector z_qf32_v; + HVX_Vector x_v; + HVX_Vector x_qf32_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector epsilon_v; + HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); + HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // exp(x) is approximated as follows: + // f = floor(x/ln(2)) = floor(x*log2(e)) + // epsilon = x - f*ln(2) + // exp(x) = exp(epsilon+f*ln(2)) + // = exp(epsilon)*exp(f*ln(2)) + // = exp(epsilon)*2^f + // + // Since epsilon is close to zero, it can be approximated with its Taylor series: + // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... + // Preserving the first eight elements, we get: + // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 + // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 + + HVX_Vector temp_v = in_vec; + + // Clamp inputs to (-20.0, 20.0) + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); + epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); + + // f_v is the floating point result and k_v is the integer result + f_v = hvx_vec_floor_fp32(epsilon_v); + k_v = hvx_vec_truncate_fp32(f_v); + + x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); + + // x = x - f_v * logn2; + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); + x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); + // normalize before every QFloat's vmpy + x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + + // z = x * x; + z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); + z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); + + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + + // y = E4 + E5 * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_5); + y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); + E_const = Q6_V_vsplat_R(EXP_COEFF_4); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_3); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_2); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_1); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_0); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = y + 1.0; + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); + + // insert exponents + // y = ldexpf(y, k); + // y_v += k_v; // qf32 + // modify exponent + + y_v = Q6_Vsf_equals_Vqf32(y_v); + + // add k_v to the exponent of y_v + HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); + + y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); + y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); + + // exponent cannot be negative; if overflow is detected, result is set to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); + + y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); + + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + + return y_v; +} + +#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation +#define RSQRT_ONE_HALF 0x3f000000 // 0.5 +#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 + +static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { + //Algorithm : + // x2 = input*0.5 + // y = * (long *) &input + // y = 0x5f3759df - (y>>2) + // y = y*(threehalfs - x2*y*y) + + HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); + HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); + HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); + + HVX_Vector x2, y, ypower2, temp; + + x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); + x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); + + y = Q6_Vw_vasr_VwR(in_vec, 1); + y = Q6_Vw_vsub_VwVw(rsqrtconst, y); + + // 1st iteration + ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); + + // 2nd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + // 3rd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + return Q6_Vsf_equals_Vqf32(temp); +} + +static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { + int step_of_1 = num_elems >> 5; + int remaining = num_elems - step_of_1 * VLEN_FP32; + + assert(remaining == 0); + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + v_dst[i] = hvx_vec_fast_sigmoid_fp32(v_src[i]); + } +} + +float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); +void hvx_mul_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_mul_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + const uint8_t * restrict src2, + uint8_t * restrict dst, + const int num_elems); +void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); +void hvx_add_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_add_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); +void hvx_sub_f32(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_sub_f32_opt(const uint8_t * restrict src0, + const uint8_t * restrict src1, + uint8_t * restrict dst, + const int num_elems); +void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); +void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale); +void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); +void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); +void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); +float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems); +float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems); +void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); +void hvx_clamp_scalar_f32(const uint8_t * restrict src, + const float limit_left, + const float limit_right, + uint8_t * restrict dst, + const int num_elems); + +#endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c new file mode 100644 index 0000000000..10e2733324 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -0,0 +1,829 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" + +#define FARF_ERROR 1 +#define FARF_HIGH 1 +#define FARF_MEDIUM 0 +#define FARF_LOW 0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "ops-utils.h" +#include "worker-pool.h" + +AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { + struct htp_context * ctx; + int err = 0; + + ctx = calloc(1, sizeof(*ctx)); + if (ctx == NULL) { + return AEE_ENOMEMORY; + } + + // Use the context structure as a handle + *handle = (remote_handle64) ctx; + + // Enable FARF logs + HAP_setFARFRuntimeLoggingParams(0xffff, NULL, 0); + + // Set client class + { + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_apptype; + request.apptype = HAP_POWER_COMPUTE_CLIENT_CLASS; + + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + } + + { + HAP_power_request_t request; + memset(&request, 0, sizeof(request)); + + request.type = HAP_power_set_DCVS_v3; + request.dcvs_v3.set_dcvs_enable = TRUE; + request.dcvs_v3.dcvs_enable = TRUE; + request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.set_bus_params = TRUE; + request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.bus_params.target_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.set_core_params = TRUE; + request.dcvs_v3.core_params.min_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.core_params.max_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; + request.dcvs_v3.set_sleep_disable = TRUE; + request.dcvs_v3.sleep_disable = TRUE; + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + + memset(&request, 0, sizeof(request)); + request.type = HAP_power_set_HVX; + request.hvx.power_up = TRUE; + if ((err = HAP_power_set((void *) ctx, &request)) != 0) { + return err; + } + } + + { + // Power on HMX + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) &ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error powering on HMX."); + return err; + } + } + + return AEE_SUCCESS; +} + +AEEResult htp_iface_close(remote_handle64 handle) { + struct htp_context * ctx = (struct htp_context *) handle; + + if (!ctx) { + return AEE_EBADPARM; + } + + if (ctx->queue) { + FARF(ERROR, "Closing handle with queue still open"); + return AEE_EITEMBUSY; + } + + free(ctx); + return AEE_SUCCESS; +} + +AEEResult htp_iface_enable_etm(remote_handle64 handle) { + int err = HAP_user_etm_enable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_disable_etm(remote_handle64 handle) { + int err = HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); + } + } + return err; +} + +static int vtcm_acquire(struct htp_context * ctx) { + if (!ctx->vtcm_valid) { + // Temporarily bump thread priority to make sure it's higher than other sessions. + // This way the resource manager will notify the other thread to release VTCM. + // Note that we need to reaquire VTCM at normal priority for this to work next time. + qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); + HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + HAP_compute_res_release_cached(ctx->vtcm_rctx); + qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); + + HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + ctx->vtcm_valid = true; + } + + ctx->vtcm_inuse = true; + return 0; +} + +static int vtcm_release(struct htp_context * ctx) { + ctx->vtcm_inuse = false; + + if (ctx->vtcm_valid && ctx->vtcm_needs_release) { + ctx->vtcm_valid = false; + ctx->vtcm_needs_release = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + } + + return 0; +} + +static int vtcm_release_callback(unsigned int rctx, void * state) { + struct htp_context * ctx = (struct htp_context *) state; + + if (!ctx || ctx->vtcm_rctx != rctx) { + return AEE_EBADPARM; + } + + // If VTCM is not inuse (not processing Ops) release it right here + // otherwise we'll release it once we're done with the current Op. + + if (ctx->vtcm_inuse) { + ctx->vtcm_needs_release = false; + return 0; + } + + ctx->vtcm_valid = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + + return 0; +} + +static int vtcm_alloc(struct htp_context * ctx) { + unsigned int vtcm_size = 8 * 1024 * 1024; // 8MB default + HAP_compute_res_query_VTCM(0, &vtcm_size, NULL, NULL, NULL); + + compute_res_attr_t attr; + HAP_compute_res_attr_init(&attr); + HAP_compute_res_attr_set_serialize(&attr, 0); + HAP_compute_res_attr_set_cache_mode(&attr, 1); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); + HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); + HAP_compute_res_attr_set_hmx_param(&attr, 1); + + // Allocate VTCM for scratch pads + uint32_t rctx = HAP_compute_res_acquire(&attr, 1000000 /* timeout */); + if (!rctx) { + FARF(ERROR, "failed to allocate %zu bytes VTCM\n", ctx->vtcm_size); + return AEE_ENOMEMORY; + } + + void * vtcm_ptr; + if (HAP_compute_res_attr_get_vtcm_ptr_v2(&attr, &vtcm_ptr, &vtcm_size) != 0) { + HAP_compute_res_release(rctx); + FARF(ERROR, "failed to allocate %zu bytes VTCM (new)\n", ctx->vtcm_size); + return AEE_ENOMEMORY; + } + + ctx->vtcm_base = (uint8_t *) vtcm_ptr; + ctx->vtcm_size = vtcm_size; + ctx->vtcm_rctx = rctx; + ctx->vtcm_valid = false; + ctx->vtcm_inuse = false; + ctx->vtcm_needs_release = false; + + return 0; +} + +static void vtcm_free(struct htp_context * ctx) { + if (ctx->vtcm_rctx) { + HAP_compute_res_release(ctx->vtcm_rctx); + ctx->vtcm_base = 0; + ctx->vtcm_rctx = 0; + } +} + +static void htp_packet_callback(dspqueue_t queue, int error, void * context); +static void htp_error_callback(dspqueue_t queue, int error, void * context); + +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { + struct htp_context * ctx = (struct htp_context *) handle; + + if (!ctx) { + return AEE_EBADPARM; + } + + if (ctx->queue) { + FARF(ERROR, "Queue already open"); + return AEE_EITEMBUSY; + } + + // Import queue created on the CPU + int err = dspqueue_import(dsp_queue_id, // Queue ID from dspqueue_export + htp_packet_callback, // Packet callback + htp_error_callback, // Error callback; no errors expected on the DSP + (void *) ctx, // Callback context + &ctx->queue); + + if (err) { + FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); + return err; + } + + ctx->thread_id = qurt_thread_get_id(); + ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); + + // allocate VTCM + err = vtcm_alloc(ctx); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Unable to allocate VTCM"); + return AEE_ENOMEMORY; + } + + qurt_sysenv_max_hthreads_t hw_threads; + qurt_sysenv_get_max_hw_threads(&hw_threads); + uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; + + if (n_hvx == 0) { + n_hvx = hw_nhvx; + } + if (n_hvx > hw_threads.max_hthreads) { + n_hvx = hw_threads.max_hthreads; + } + if (n_hvx > HTP_MAX_NTHREADS) { + n_hvx = HTP_MAX_NTHREADS; + } + + ctx->n_threads = n_hvx; + for (int i = 0; i < ctx->n_threads; i++) { + ctx->dma[i] = dma_queue_create(HTP_SPAD_SRC0_NROWS * 2); + } + + // init worker pool + err = worker_pool_init(&ctx->worker_pool, n_hvx); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Unable to create worker pool"); + return err; + } + + FARF(HIGH, "session %u started: n-hvx %u vtcm-size %zu vtcm-rctx %u n-threads %u thread-id %d thread-prio %d \n", + sess_id, hw_nhvx, ctx->vtcm_size, ctx->vtcm_rctx, ctx->n_threads, ctx->thread_id, ctx->thread_prio); + + return AEE_SUCCESS; +} + +AEEResult htp_iface_stop(remote_handle64 handle) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (!ctx->queue) { + FARF(ERROR, "Queue not open"); + return AEE_EBADSTATE; + } + + // Close queue. dspqueue_close() will also wait for callbacks to finish. + int err = dspqueue_close(ctx->queue); + ctx->queue = NULL; + if (err != 0) { + FARF(ERROR, "Queue close failed with 0x%08x", (unsigned) err); + return err; + } + + if (ctx->worker_pool) { + // Release worker pool + worker_pool_release(&ctx->worker_pool); + } + + for (int i = 0; i < ctx->n_threads; i++) { + dma_queue_delete(ctx->dma[i]); + } + + vtcm_free(ctx); + + return AEE_SUCCESS; +} + +static void htp_error_callback(dspqueue_t queue, int error, void * context) { + // No errors expected on the DSP. + FARF(ERROR, "Error callback: 0x%08x", (unsigned) error); +} + +struct profile_data { + uint64_t usecs; + uint64_t cycles; + uint64_t pkts; +}; + +static inline void profile_start(struct profile_data * d) { + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = htp_get_cycles(); + d->pkts = htp_get_pktcnt(); +} + +static inline void profile_stop(struct profile_data * d) { + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = htp_get_cycles() - d->cycles; + d->pkts = htp_get_pktcnt() - d->pkts; +} + +static int send_htp_rsp(struct htp_context * c, + uint32_t op, + uint32_t status, + struct dspqueue_buffer * bufs, + size_t n_bufs, + struct profile_data * prof) { + // Prep response struct + struct htp_general_rsp rsp; + rsp.op = op; + rsp.status = status; + rsp.prof_usecs = prof->usecs; + rsp.prof_cycles = prof->cycles; + rsp.prof_pkts = prof->pkts; + + int err = dspqueue_write(c->queue, + 0, // Flags + n_bufs, + bufs, // Buffer references + sizeof(rsp), + (const uint8_t *) &rsp, // Message + DSPQUEUE_TIMEOUT_NONE); + + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + } + + return err; +} + +static void proc_matmul_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_matmul(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_matmul_id_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_matmul_id(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_binary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[3].fd; + rsp_bufs[0].ptr = bufs[3].ptr; + rsp_bufs[0].offset = bufs[3].offset; + rsp_bufs[0].size = bufs[3].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_binary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_unary(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_activations_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + int write_idx = (n_bufs == 3) ? 2 : 1; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + if (3 == n_bufs) { + octx.src1 = req->src1; + } + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + if (3 == n_bufs) { + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + } else { + octx.dst.data = (uint32_t) bufs[1].ptr; + } + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + if (octx.op == HTP_OP_SOFTMAX) { + rsp_status = op_softmax(&octx); + } else { + rsp_status = op_activations(&octx); + } + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_rope_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + int write_idx = (n_bufs == 4) ? 3 : 2; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[write_idx].fd; + rsp_bufs[0].ptr = bufs[write_idx].ptr; + rsp_bufs[0].offset = bufs[write_idx].offset; + rsp_bufs[0].size = bufs[write_idx].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + if (4 == n_bufs) { + octx.src2 = req->src2; + } + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + if (4 == n_bufs) { + octx.src2.data = (uint32_t) bufs[2].ptr; + octx.dst.data = (uint32_t) bufs[3].ptr; + } else { + octx.dst.data = (uint32_t) bufs[2].ptr; + } + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_rope(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void htp_packet_callback(dspqueue_t queue, int error, void * context) { + struct htp_context * ctx = (struct htp_context *) context; + + // Repeatedly read packets from the queue until it's empty. We don't + // necessarily get a separate callback for each packet, and new packets + // may arrive while we're processing the previous one. This ensures we + // keep the DSP busy as much as possible and avoid waiting for the CPU. + + while (1) { + struct htp_general_req req; + uint32_t req_size; + + struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; + uint32_t n_bufs; + uint32_t flags; + + // Read packet from queue + int err = dspqueue_read_noblock(queue, &flags, + HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references + &n_bufs, // Number of buffer references + bufs, // Buffer references + sizeof(req), // Max message length + &req_size, // Message length + (uint8_t *) &req); // Message + + if (err == AEE_EWOULDBLOCK) { + // Consumed all packets available for now + return; + } + + if (err != 0) { + FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err); + return; + } + + if (req_size != sizeof(req)) { + FARF(ERROR, "Invalid request size"); + continue; + } + + if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) { + // Host wants early notification + dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + } + + // Process packet based on its message type + switch (req.op) { + case HTP_OP_MUL_MAT: + if (n_bufs != 3) { + FARF(ERROR, "Bad matmul-req buffer list"); + continue; + } + proc_matmul_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_MUL_MAT_ID: + if (n_bufs != 4) { + FARF(ERROR, "Bad matmul-id-req buffer list"); + continue; + } + proc_matmul_id_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + if (n_bufs != 3) { + FARF(ERROR, "Bad binary-req buffer list"); + continue; + } + proc_binary_req(ctx, &req, bufs); + break; + + case HTP_OP_RMS_NORM: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_unary_req(ctx, &req, bufs); + break; + + case HTP_OP_UNARY_SILU: + if (n_bufs != 2) { + FARF(ERROR, "Bad act-req buffer list"); + continue; + } + proc_activations_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_GLU_SWIGLU: + case HTP_OP_SOFTMAX: + if ((n_bufs != 2) && (n_bufs != 3)) { + FARF(ERROR, "Bad act-req buffer list"); + continue; + } + proc_activations_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_ADD_ID: + if (n_bufs != 4) { + FARF(ERROR, "Bad add-id-req buffer list"); + continue; + } + proc_add_id_req(ctx, &req, bufs); + break; + + case HTP_OP_ROPE: + if ((n_bufs != 3) && (n_bufs != 4)) { + FARF(ERROR, "Bad rope-req buffer list"); + continue; + } + proc_rope_req(ctx, &req, bufs, n_bufs); + break; + + default: + FARF(ERROR, "Unknown Op %u", req.op); + break; + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c new file mode 100644 index 0000000000..c99b6a0d18 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -0,0 +1,2223 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +struct htp_matmul_type { + const char * type; + void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); + void (*vec_dot_rx2)(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy); +}; + +typedef struct { + HVX_Vector v[2]; +} HVX_Vector_x2; + +typedef struct { + HVX_Vector v[4]; +} HVX_Vector_x4; + +typedef struct { + HVX_Vector v[8]; +} HVX_Vector_x8; + +// vdelta control to replicate first 4x fp32 values across lanes +static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, + 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, + 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, + 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, + 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, +}; + +// vdelta control to replicate and interleave first 8x fp32 values across lanes +static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, + 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, + 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, + 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44, + 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, +}; + +// vdelta control to replicate first fp32 value across all elements +static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, + 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, + 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, + 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, + 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, + 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, + 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, +}; + +// vdelta control to replicate first fp16 value across all elements +static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, + 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, + 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, + 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, + 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, + 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, +}; + +// vdelta control to expand first 32 e8m0 values into 32 uint32 elements +static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { + 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, + 0x00, 0x11, 0x10, 0x10, 0x10, 0x02, 0x00, 0x04, 0x00, 0x01, 0x02, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, + 0x00, 0x00, 0x22, 0x20, 0x20, 0x20, 0x21, 0x22, 0x20, 0x24, 0x04, 0x00, 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x02, + 0x00, 0x04, 0x00, 0x11, 0x12, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, + 0x01, 0x02, 0x00, 0x04, 0x44, 0x40, 0x40, 0x40, 0x41, 0x40, 0x40, 0x40, 0x42, 0x40, 0x44, 0x40, 0x41, 0x42, 0x48, + 0x48, 0x08, 0x08, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x12, 0x10, 0x10, 0x10, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, + 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, +}; + +static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, + 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales + +static inline size_t q8x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128); +} + +static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + // Convert uint4 to int4 (i.e. x - 8) + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0 = vptr[0]; // first 128 vals + HVX_Vector v1 = vptr[1]; // ... + HVX_Vector v2 = vptr[2]; // ... + HVX_Vector v3 = vptr[3]; // ... + HVX_Vector v4 = vptr[4]; // ... + HVX_Vector v5 = vptr[5]; // ... + HVX_Vector v6 = vptr[6]; // ... + HVX_Vector v7 = vptr[7]; // ... + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0 = vptr[0]; // first 64 vals + HVX_Vector v1 = vptr[1]; // second 64 vals + HVX_Vector v2 = vptr[2]; // third 64 vals + HVX_Vector v3 = vptr[3]; // forth 64 vals + + HVX_Vector_x4 r = { v0, v1, v2, v3 }; + return r; +} + +static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { + const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; + + HVX_VectorPair v0 = vptr[0]; // first 64 vals + HVX_VectorPair v1 = vptr[1]; // second 64 vals + HVX_VectorPair v2 = vptr[2]; // third 64 vals + HVX_VectorPair v3 = vptr[3]; // forth 64 vals + + HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); + HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); + HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); + HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); + HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); + HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); + HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); + HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); + + HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); + HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); + HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); + HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); + + // vcombine does a shuffle, use vdeal to undo + + HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; + return r; +} + +// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). +// Accumulate each block into a single int32 value. +// Return a single HVX vector with 32x int32 accumulators. +// This version is parameterized to support less than 1024 elements. +// if() checks are optimized out at compile time -- make sure to pass N as a constexpr. + +static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + HVX_Vector r0 = Q6_V_vsplat_R(0); + HVX_Vector r1 = Q6_V_vsplat_R(0); + HVX_Vector r2 = Q6_V_vsplat_R(0); + HVX_Vector r3 = Q6_V_vsplat_R(0); + HVX_Vector r4 = Q6_V_vsplat_R(0); + HVX_Vector r5 = Q6_V_vsplat_R(0); + HVX_Vector r6 = Q6_V_vsplat_R(0); + HVX_Vector r7 = Q6_V_vsplat_R(0); + + HVX_VectorPair p3; + HVX_VectorPair p2; + HVX_VectorPair p1; + HVX_VectorPair p0; + + if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); } + if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); } + if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); } + if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); } + if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); } + if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); } + if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); } + if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } + if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); } + if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); } + + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } + if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); } + if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } + + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } + + if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } + if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } + + return r0; +} + +static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { + return hvx_vec_rmpy_x8_n(x, y, 1024); +} + +// Handle most common cases of tensors not multiple of 1024. +static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }; + if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }; + if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }; + return hvx_vec_rmpy_x8_n(x, y, 1024); +} + +static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Reduce and convert into fp32 + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + + hvx_vec_store_u(&s[0], 4, r0_sum); +} + +static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Convert into fp32 and reduce + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Reduce and convert into fp32 + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + + hvx_vec_store_u(&s[0], 4, r0_sum); +} + +static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Convert into fp32 and reduce + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_mxfp4x4x2_q8x4x2(const int n, + float * restrict s, + const void * restrict vx, + const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + // Zero-out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + } + + // Reduce and convert into fp32 + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + + hvx_vec_store_u(&s[0], 4, r0_sum); +} + +static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx % 128 == 0); + assert((unsigned long) vy % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_sum = Q6_V_vsplat_R(0); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + // Zero-out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + } + + // Convert into fp32 and reduce + r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +#if 1 +static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + if (0) { + float rsum = 0; + const __fp16 * restrict vx = (const __fp16 * restrict) x; + const float * restrict vy = (const float * restrict) y; + + for (uint32_t i = 0; i < n; i++) { + rsum += vx[i] * (__fp16) vy[i]; + } + *s = rsum; + return; + } + + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y; + + uint32_t nv0 = n / 64; // num full fp16 hvx vectors + uint32_t nv1 = n % 64; // leftover elements + + // for some reason we need volatile here so that the compiler doesn't try anything funky + volatile HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + for (i = 0; i < nv0; i++) { + HVX_VectorPair yp = vy[i]; + + HVX_Vector x = vx[i]; + HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 + + HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + + HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + } + + if (nv1) { + HVX_VectorPair yp = vy[i]; + + HVX_Vector x = vx[i]; + HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 + + if (nv1 >= 32) { + HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); + nv1 -= 32; + } + + rsum = hvx_vec_qf32_reduce_sum(rsum); + + if (nv1) { + HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + } + + // hvx_vec_dump_fp16("X", x); + // hvx_vec_dump_fp16("Y", y); + // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum)); + // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum)); + } else { + rsum = hvx_vec_qf32_reduce_sum(rsum); + } + + *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)); + +# ifdef HTP_DEBUG + { + float rsum = 0; + const __fp16 * restrict vx = (const __fp16 * restrict) x; + const float * restrict vy = (const float * restrict) y; + + for (uint32_t i = 0; i < n; i++) { + rsum += vx[i] * vy[i]; + } + + float diff = fabs(*s - rsum); + if (diff > 0.001) { + FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s); + // htp_dump_f16("x", vx, n); + // htp_dump_f32("y", vy, n); + } + } +# endif +} +#else +static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const uint32_t fk = 64; + const uint32_t nb = n / fk; + + assert(n % fk == 0); + assert(nb % 4 == 0); + + const uint32_t x_blk_size = 2 * fk; // fp16 + const uint32_t y_blk_size = 4 * fk; // fp32 + + // Row sum (qf32) + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_Vector rsum2 = Q6_V_vsplat_R(0); + HVX_Vector rsum3 = Q6_V_vsplat_R(0); + + for (uint32_t i = 0; i < nb; i += 4) { + HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size)); + HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size)); + + HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]); + HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]); + HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]); + HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1))); + rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2))); + rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3))); + } + + // Reduce and convert into fp32 + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1); + rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3); + HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2)); + hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum)); +} +#endif + +#define htp_matmul_preamble \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +// q8x4 src1 tensor is already in VTCM spad +static void matmul(struct htp_matmul_type * mt, + struct htp_tensor * restrict src0, + struct htp_tensor * restrict src1, + struct htp_tensor * restrict dst, + struct htp_spad * restrict src0_spad, + struct htp_spad * restrict src1_spad, + struct htp_spad * restrict dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_matmul_preamble; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + volatile uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= HTP_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue); + + #pragma unroll(2) + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + } + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row); + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue); + + #pragma unroll(2) + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); + mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], + src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// q8x4x2 src1 tensor is already in VTCM spad +static void matvec(struct htp_matmul_type * mt, + struct htp_tensor * restrict src0, + struct htp_tensor * restrict src1, + struct htp_tensor * restrict dst, + struct htp_spad * restrict src0_spad, + struct htp_spad * restrict src1_spad, + struct htp_spad * restrict dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_matmul_preamble; + + const uint32_t src0_nrows = ne01; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * src1_data = src1_spad->data; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + float * tmp = (float *) spad_dst; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; + float * restrict dst_col = (float *) dst->data; + + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= HTP_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue); + mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue); + mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } + + hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], + src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)] + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +// q8x4 src1 tensor is already in VTCM spad +static void matmul_id(struct htp_matmul_type * mt, + struct htp_tensor * restrict src0, + struct htp_tensor * restrict src1, + struct htp_tensor * restrict ids, + struct htp_tensor * restrict dst, + struct htp_spad * restrict src0_spad, + struct htp_spad * restrict src1_spad, + struct htp_spad * restrict src2_spad, + struct htp_spad * restrict dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src1_nrows = ne11; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint32_t n_ids = ids->ne[0]; // n_expert_used + const uint32_t n_as = ne02; // n_expert + + const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); + const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + + const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; + const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= HTP_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue); + + for (uint32_t cid = 0; cid < cne1; ++cid) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); + const int rm1 = row_mapping.i1; // expert idx + const int rm2 = row_mapping.i2; // token idx + + const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx + const uint8_t * restrict src1_col = + (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + } + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue); + + for (uint32_t cid = 0; cid < cne1; ++cid) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); + const int rm1 = row_mapping.i1; // expert idx + const int rm2 = row_mapping.i2; // token idx + + const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx + const uint8_t * restrict src1_col = + (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + + mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// q8x4 src1 tensor is already in VTCM spad +static void matvec_id(struct htp_matmul_type * mt, + struct htp_tensor * restrict src0, + struct htp_tensor * restrict src1, + struct htp_tensor * restrict src2, + struct htp_tensor * restrict dst, + struct htp_spad * restrict src0_spad, + struct htp_spad * restrict src1_spad, + struct htp_spad * restrict src2_spad, + struct htp_spad * restrict dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01; // src0 rows per expert + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + assert(ne13 % ne03 == 0); + + const size_t dst_row_size = nb1; + const size_t src0_row_size = nb01; + const size_t src1_row_size = q8x4x2_row_size(ne10); + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + + const uint32_t n_aids = src2->ne[0]; // num activated experts + const uint32_t n_ids = ne02; // num experts + + // Per-thread VTCM scratchpads for all tensors + // Note that the entire src1 tensor is already in VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; + uint8_t * restrict src1_data = src1_spad->data; + + for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert + const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); + assert(eid < n_ids); + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; + float * restrict dst_row = (float *) (dst->data + ie1 * nb1); + + // Prefill spad with src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= HTP_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + + // Prefetch next (n + spad_nrows) row + const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size, + src0_row_size_padded, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push(dma_queue, spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size, + src0_row_size_padded, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue); + mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// *** matmul in fp16 + +static void matmul_f16_f32(struct htp_tensor * restrict src0, + struct htp_tensor * restrict src1, + struct htp_tensor * restrict dst, + struct htp_spad * restrict src0_spad, + struct htp_spad * restrict src1_spad, + struct htp_spad * restrict dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const size_t src0_row_size = sizeof(__fp16) * ne00; + const size_t src1_row_size = sizeof(float) * ne10; + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const uint32_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const uint32_t nr1 = ne1 * ne2 * ne3; + + uint32_t chunk_size = 64; + + // distribute the thread work across the inner or outer loop based on which one is larger + uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + // The number of elements in each chunk + const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + uint32_t current_chunk = ith; + + const uint32_t ith0 = current_chunk % nchunk0; + const uint32_t ith1 = current_chunk / nchunk0; + + const uint32_t ir0_start = dr0 * ith0; + const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); + + const uint32_t ir1_start = dr1 * ith1; + const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); + + // broadcast factors + const uint32_t r2 = ne12 / ne02; + const uint32_t r3 = ne13 / ne03; + + // no work for this thread + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + // block-tiling attempt + const uint32_t blck_0 = 64; + const uint32_t blck_1 = 64; + + float tmp[32]; + + for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (uint32_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1++) { + const uint32_t i13 = (ir1 / (ne12 * ne1)); + const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; + const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const uint32_t i03 = i13 / r3; + const uint32_t i02 = i12 / r2; + + const uint32_t i1 = i11; + const uint32_t i2 = i12; + const uint32_t i3 = i13; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src1_col = + (const uint8_t *) src1->data + (i11 + i12 * ne11 + i13 * ne12 * ne11) * src1_row_size; + float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + for (uint32_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0++) { + vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row + ir0 * src0_row_size, src1_col); + } + + hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// *** dynamic quant + +static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + + // Load and convert into QF32 + HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + // Compute max and scale + HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf); + + // Replicate first fp16 scale across all lanes + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16; + vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + *(HVX_UVector *) y_d = vd_hf; + + // Divide input by the scale + HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; +} + +// Overrides input x +static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 2; // 8x __fp16 + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, + t_d + (i * 2 + 0) * dblk_size / 2); + quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, + t_d + (i * 2 + 1) * dblk_size / 2); + } + + // now copy the scales into final location + hvx_copy_fp16_ua(y_d, t_d, nb * 8); +} + +static void quantize_fp32_q8x4x2(const struct htp_tensor * src, + uint8_t * restrict dst, + struct htp_spad * spad, + uint32_t nth, + uint32_t ith, + uint32_t nrows_per_thread) { + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_row_size); + hvx_copy_fp32_aa(tmp_data, src_data, ne0); + + // FARF(HIGH, "quantize-q8x4-row: %u\n", i); + quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); +} + +// ** matmul callbacks for worker_pool + +static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + + matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + + matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q8x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q8x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; + + matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q8x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q8x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; + + matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "mxfp4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; + + matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "mxfp4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; + + matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, + octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +// ** matmul-id callbacks for worker_pool + +static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + + matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + + matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q8x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q8x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; + + matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "q8x4x2-q8x4x2"; + mt.vec_dot = vec_dot_q8x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; + + matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "mxfp4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; + + matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "mxfp4x4x2-q8x4x2"; + mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; + mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; + + matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, + &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + +// ** main matmul entry point + +int op_matmul(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + htp_matmul_preamble; + + const char * op_type; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + size_t src1_row_size = nb11; + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + size_t src1_row_size_padded; + + worker_callback_t quant_job_func; + worker_callback_t matmul_job_func; + + bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); + + switch (src0->type) { + case HTP_TYPE_Q4_0: + op_type = "q4x4x2-fp32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_q4x4x2_q8x4x2; + } else { + matmul_job_func = htp_matvec_q4x4x2_q8x4x2; + } + + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + case HTP_TYPE_Q8_0: + op_type = "q8x4x2-fp32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_q8x4x2_q8x4x2; + } else { + matmul_job_func = htp_matvec_q8x4x2_q8x4x2; + } + + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + case HTP_TYPE_MXFP4: + op_type = "mxfp4x4x2-f32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2; + } else { + matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2; + } + + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + case HTP_TYPE_F16: + op_type = "f16-f32"; + quant_job_func = NULL; // htp_quantize_f32_f16; + matmul_job_func = htp_matmul_f16_f32; + + // For all tensors we allocate N rows per thread, padded to HVX vector size + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + need_quant = false; + break; + + default: + return HTP_STATUS_NO_SUPPORT; + } + + // VTCM scratchpads for all tensors + size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); + + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + octx->ctx->vtcm_size, spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + + if (need_quant) { + // Run quant jobs + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + } + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + // Run matmul jobs + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + } + + return HTP_STATUS_OK; +} + +// ** main matmul-id entry point + +int op_matmul_id(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * ids = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + htp_matmul_preamble; + + const char * op_type; + + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func; + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + + const uint32_t src0_nrows = ne01; // per expert + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + size_t src1_row_size; + size_t src1_row_size_padded; + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + size_t matrix_row_counts_size = n_as * sizeof(uint32_t); + size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + + switch (src0->type) { + case HTP_TYPE_Q4_0: + op_type = "q4x2x2-f32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + if (src1_nrows > 1) { + matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; + } else { + matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; + } + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src2_spad.size = octx->src2_spad.size_per_thread; + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + case HTP_TYPE_Q8_0: + op_type = "q8x2x2-f32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + if (src1_nrows > 1) { + matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; + } else { + matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; + } + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src2_spad.size = octx->src2_spad.size_per_thread; + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + case HTP_TYPE_MXFP4: + op_type = "mxfp4x2x2-f32"; + quant_job_func = htp_quantize_fp32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + if (src1_nrows > 1) { + matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; + } else { + matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; + } + + // Entire src1 tensor is placed into the VTCM + // For other tensors we allocate N rows per thread, padded to HVX vector size + octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src2_spad.size = octx->src2_spad.size_per_thread; + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + break; + + default: + return HTP_STATUS_NO_SUPPORT; + } + + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); + + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, + src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + octx->ctx->vtcm_size, spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + + octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + + if (src1_nrows > 1) { + // initialize matrix_row_counts and map + uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; + struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; + + memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); + + // group rows by src0 matrix + for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx + for (uint32_t id = 0; id < n_ids; ++id) { // expert idx + const uint32_t i02 = + *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_row_counts[i02] += 1; + } + } + } + + // Setup worker pool callbacks + if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { + // Run quant jobs + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + } + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + // Run matmul-id jobs + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/ops-utils.h b/ggml/src/ggml-hexagon/htp/ops-utils.h new file mode 100644 index 0000000000..f03ff34028 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/ops-utils.h @@ -0,0 +1,116 @@ +#ifndef OPS_UTILS_H +#define OPS_UTILS_H + +#include "htp-msg.h" + +#ifndef MAX +# define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +# define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint64_t htp_get_cycles() { + uint64_t cycles = 0; + asm volatile(" %0 = c15:14\n" : "=r"(cycles)); + return cycles; +} + +static inline uint64_t htp_get_pktcnt() { + uint64_t pktcnt; + asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); + return pktcnt; +} + +static inline int32_t htp_is_aligned(void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline uint32_t htp_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) { + const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); + asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control)); +} + +static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) { + char str[1024], *p = str; + p += sprintf(p, "%s: ", pref); + for (int i = 0; i < 16; i++) { + p += sprintf(p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) { + char str[1024], *p = str; + p += sprintf(p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += sprintf(p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { + char str[1024], *p = str; + p += sprintf(p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += sprintf(p, "%d, ", (int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) { + char str[1024], *p = str; + p += sprintf(p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += sprintf(p, "%.6f, ", (float) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) { + char str[1024], *p = str; + p += sprintf(p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += sprintf(p, "%.6f, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + htp_dump_fp32_line(pref, x + (16 * i), 16); + } + if (n1) { + htp_dump_fp32_line(pref, x + (16 * i), n1); + } +} + +static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + htp_dump_fp16_line(pref, x + (16 * i), 16); + } + if (n1) { + htp_dump_fp16_line(pref, x + (16 * i), n1); + } +} + +#endif /* OPS_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c new file mode 100644 index 0000000000..16afa50f5b --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -0,0 +1,418 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define htp_rope_preamble \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct rope_th_ctx { + int32_t n_dims; + int32_t mode; + int32_t n_ctx_orig; + int32_t sections[4]; + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + float theta_scale; + float corr_dims[2]; + + struct htp_ops_context * octx; +}; + +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + + return (1 - MIN(1, MAX(0, y))); +} + +static void rope_cache_init(const float theta_base, + float freq_scale, + const float * freq_factors, + float * corr_dims, + uint32_t ne0, + float ext_factor, + float mscale, + float * cache, + float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta = theta_base; + + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + + float theta_extrap = theta / ff; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta2 = theta_interp; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta2) * mscale; + cache[i0 + 1] = sinf(theta2) * mscale; + + theta *= theta_scale; + } +} + +#define M_PI 3.1415926535897932384626433 + +static void rope_corr_dims(int n_dims, + int n_ctx_orig, + float freq_base, + float beta_fast, + float beta_slow, + float * dims) { + float start = floorf(n_dims * logf(n_ctx_orig / (beta_fast * 2 * (float) M_PI)) / (2 * logf(freq_base))); + float end = ceilf(n_dims * logf(n_ctx_orig / (beta_slow * 2 * (float) M_PI)) / (2 * logf(freq_base))); + dims[0] = MAX(0, start); + dims[1] = MIN(n_dims - 1, end); +} + +static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) { + memset(rope_ctx, 0, sizeof(struct rope_th_ctx)); + + const int32_t * op_params = &octx->op_params[0]; + + rope_ctx->n_dims = ((const int32_t *) op_params)[1]; + rope_ctx->mode = ((const int32_t *) op_params)[2]; + rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4]; + + memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float)); + memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float)); + memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float)); + memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float)); + memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float)); + memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float)); + memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4); + + rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims); + + rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast, + rope_ctx->beta_slow, rope_ctx->corr_dims); + + rope_ctx->octx = octx; + FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims, + rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); +} + +static void hvx_calc_rope_f32(const float * restrict src0, + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { + // for (int i = 0; i < num_elems; i += 2) { + //const float cos_theta = theta_cache[i + 0]; + //const float sin_theta = theta_cache[i + 1]; + + //const float x0 = src[0]; + //const float x1 = src[1]; + + //dst[0] = x0*cos_theta - x1*sin_theta; + //dst[1] = x0*sin_theta + x1*cos_theta; + + //src += 2; + //dst += 2; + // } + + const uint8_t * restrict src0_curr = (const uint8_t *) src0; + const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; + uint8_t * restrict dst_curr = (uint8_t *) dst; + + int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once + + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v0 = *(HVX_Vector *) src0_curr; + HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); + + HVX_Vector v2 = *(HVX_Vector *) theta_curr; + HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); + *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); + + src0_curr += 2 * VLEN; + theta_curr += 2 * VLEN; + dst_curr += 2 * VLEN; + } +} + +static void rope_hex_f32(struct rope_th_ctx * rope_ctx, + const uint32_t ir0, + const uint32_t ir1, + int nth, + int ith, + int opt_path) { + struct htp_ops_context * octx = rope_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + htp_rope_preamble; + + const int32_t * pos = (const int32_t *) src1->data; + + float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); + + const float * freq_factors = NULL; + if (src2 != NULL) { + freq_factors = (const float *) src2->data; + } + + int ir = 0; + + for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch + for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len + const int32_t p = pos[i2]; + + rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, + rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads + if (ir++ < ir0) { + continue; + } + if (ir > ir1) { + break; + } + + const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); + float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); + + const float * src_loc = src; + float * dst_data_loc = dst_data; + + if (1 == opt_path) { + hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); + } else { + for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { + const float cos_theta = wp0[i0 + 0]; + const float sin_theta = wp0[i0 + 1]; + + const float x0 = src_loc[0]; + const float x1 = src_loc[1]; + + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + + src_loc += 2; + dst_data_loc += 2; + } + } + + for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) { + dst_data_loc[0] = src_loc[0]; + dst_data_loc[1] = src_loc[1]; + + src_loc += 2; + dst_data_loc += 2; + } + } + } + } +} + +static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { + struct htp_ops_context * octx = rope_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + htp_rope_preamble; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || + (0 == htp_is_aligned((void *) dst->data, VLEN))) { + FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); + is_aligned = 0; + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { + struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; + + rope_job_f32_per_thread(rope_ctx, n, i); +} + +static int execute_op_rope_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t op_func; + const char * op_type = NULL; + + struct rope_th_ctx rope_ctx; + + switch (octx->op) { + case HTP_OP_ROPE: + op_func = rope_job_dispatcher_f32; + op_type = "rope-f32"; + + init_rope_ctx(&rope_ctx, octx); + break; + + default: + FARF(ERROR, "Unsupported Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src0_row_size; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (src2->ne[0]) { + FARF(HIGH, + "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " + "dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + } else { + FARF(HIGH, + "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + } + + return err; +} + +int op_rope(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_rope_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c new file mode 100644 index 0000000000..5bf0cbf792 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -0,0 +1,402 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ + const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ + const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ + const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ + const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ + const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ + const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct softmax_th_ctx { + bool use_f16; + bool use_src1; + uint32_t n_head; + uint32_t n_head_log2; + + float scale; + float max_bias; + float m0; + float m1; + + struct htp_ops_context * octx; +}; + +static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + + memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + + memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + + softmax_ctx->n_head = src0->ne[2]; + softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + + softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); + softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + + softmax_ctx->use_src1 = (src1->ne[0] != 0); + softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + + softmax_ctx->octx = octx; +} + +static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + const int num_elems, + float scale, + const uint8_t * restrict mask, + float slope) { + const uint8_t * restrict src_curr = src; + uint8_t * restrict dst_curr = dst; + const uint8_t * restrict mask_curr = mask; + + HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); + HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + + int step_of_1 = num_elems >> 5; + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = *(HVX_Vector *) src_curr; + + HVX_Vector v3 = *(HVX_Vector *) mask_curr; + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec); + + HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v3, slope_vec); + + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, v4); + + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v5); + + src_curr += VLEN; + dst_curr += VLEN; + mask_curr += VLEN; + } +} + +static void hvx_fast_softmax_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems) { + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_pad = (HVX_Vector *) pad; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); + HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]); + HVX_Vector zero_v = Q6_V_vzero(); + HVX_Vector one_v = hvx_vec_splat_fp32(1.0); + + int step_of_1 = num_elems >> 5; + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); + } + + HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec); + max_vec = hvx_vec_repl4(v); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec); + + HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2)); + + sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3); + + v_pad[i] = v3; + } + + v = hvx_vec_qf32_reduce_sum(sum_vec); + sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); + + HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); + HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec); + HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_pad[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_vec); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } +} + +static float hvx_softmax_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict spad, + const int num_elems, + const float max) { + hvx_sub_scalar_f32(src, max, spad, num_elems); + + hvx_exp_f32(spad, dst, num_elems, false); + + float sum = hvx_self_sum_f32(dst, num_elems); + + return sum; +} + +static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { + struct htp_ops_context * octx = softmax_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * dst = &octx->dst; + + htp_softmax_preamble3; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); + + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; + + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { + const uint32_t i11 = i01; + const uint32_t i12 = i02 % ne12; + const uint32_t i13 = i03 % ne13; + + // ALiBi + const uint32_t h = i02; // head + + const float slope = (softmax_ctx->max_bias > 0.0f) ? + h < softmax_ctx->n_head_log2 ? + powf(softmax_ctx->m0, h + 1) : + powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : + 1.0f; + + float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); + float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (softmax_ctx->use_src1) ? + (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + float * mp_f32 = (softmax_ctx->use_src1) ? + (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + + if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, + (const uint8_t *) mp_f32, slope); + } else { + hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale); + if (mp_f32) { + if (softmax_ctx->use_f16) { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } + } + } + + if (1 == opt_path) { + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? (1.0 / sum) : 1; + hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum); + } + } + } + } +} + +static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { + struct htp_ops_context * octx = softmax_ctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + htp_softmax_preamble3; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + is_aligned = 0; + FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { + struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; + softmax_job_f32_per_thread(p_softmax_ctx, n, i); +} + +static int execute_op_softmax_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t op_func; + const char * op_type = NULL; + + struct softmax_th_ctx softmax_ctx; + + switch (octx->op) { + case HTP_OP_SOFTMAX: + op_func = softmax_job_dispatcher_f32; + op_type = "softmax-f32"; + + init_softmax_ctx(&softmax_ctx, octx); + break; + + default: + FARF(ERROR, "Unsupported Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t n_threads = octx->n_threads; + + const size_t src0_row_size = src0->nb[1]; + const size_t src1_row_size = src0_row_size; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (src1->ne[0]) { + FARF(HIGH, + "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], + src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, + octx->dst_spad.size); + } else { + FARF(HIGH, "%s: %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + } + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs); + } + + return err; +} + +int op_softmax(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_softmax_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c new file mode 100644 index 0000000000..bb7557b025 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -0,0 +1,255 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define htp_unary_preamble \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t ne01 = src->ne[1]; \ + const uint32_t ne02 = src->ne[2]; \ + const uint32_t ne03 = src->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src->nb[0]; \ + const uint32_t nb01 = src->nb[1]; \ + const uint32_t nb02 = src->nb[2]; \ + const uint32_t nb03 = src->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon); + + int step_of_1 = num_elems >> 5; + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v); + sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); + + HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + + #pragma unroll(4) + for (int i = 0; i < step_of_1; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } +} + +static void rms_norm_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + } + + if (1 == opt_path) { + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } else { + float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); + + const float mean = sum / row_elems; + const float scale = 1.0f / sqrtf(mean + epsilon); + + hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale); + } + } +} + +static void unary_job_f32_per_thread(const struct htp_tensor * src, + struct htp_tensor * dst, + uint8_t * spad, + int htp_op, + int32_t * op_params, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread) { + htp_unary_preamble; + + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + int is_aligned = 1; + int opt_path = 0; + if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) { + is_aligned = 0; + FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n"); + } + if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src = (const uint8_t *) src->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); + float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + + switch (htp_op) { + case HTP_OP_RMS_NORM: + rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + + default: + break; + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + + unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, + octx->src0_nrows_per_thread); +} + +static int execute_op_unary_f32(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + const struct htp_tensor * src0 = &octx->src0; + struct htp_tensor * dst = &octx->dst; + + worker_callback_t unary_op_func; + const char * op_type = NULL; + + switch (octx->op) { + case HTP_OP_RMS_NORM: + unary_op_func = unary_job_dispatcher_f32; + op_type = "rmsnorm-f32"; + break; + + default: + FARF(ERROR, "Unsupported unary Op %u\n", octx->op); + return HTP_STATUS_NO_SUPPORT; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + + const size_t src0_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + + // VTCM scratchpads for all tensors + octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + + size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < spad_size) { + FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + } + + return err; +} + +int op_unary(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src0.type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.c b/ggml/src/ggml-hexagon/htp/worker-pool.c new file mode 100644 index 0000000000..cd38c2126c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/worker-pool.c @@ -0,0 +1,297 @@ +#include "worker-pool.h" + +#include +#include +#include +#include +#include +#include + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif + +#include "HAP_farf.h" + +#define WORKER_THREAD_STACK_SZ (2 * 16384) +#define LOWEST_USABLE_QURT_PRIO (254) + +struct worker_pool_s; + +// internal structure kept in thread-local storage per instance of worker pool +typedef struct { + struct worker_pool_s * pool; + unsigned int id; +} worker_context_t; + +// internal structure kept in thread-local storage per instance of worker pool +typedef struct worker_pool_s { + worker_pool_job_t job[MAX_NUM_WORKERS]; // list of job descriptors + qurt_thread_t thread[MAX_NUM_WORKERS]; // thread ID's of the workers + worker_context_t context[MAX_NUM_WORKERS]; // worker contexts + void * stack[MAX_NUM_WORKERS]; // thread stack pointers + unsigned int n_threads; // number of workers in this pool + + atomic_uint seqn; // seqno used to detect new jobs + atomic_uint next_job; // next job index + atomic_uint n_pending; // number of pending jobs + atomic_uint n_jobs; // number of current jobs + atomic_bool killed; // threads need to exit +} worker_pool_t; + +static void worker_pool_main(void * context) { + worker_context_t * me = (worker_context_t *) context; + worker_pool_t * pool = me->pool; + + FARF(HIGH, "worker-pool: thread %u started", me->id); + + unsigned int prev_seqn = 0; + while (!atomic_load(&pool->killed)) { + unsigned int seqn = atomic_load(&pool->seqn); + if (seqn == prev_seqn) { + // Nothing to do + qurt_futex_wait(&pool->seqn, prev_seqn); + continue; + } + + // New job + prev_seqn = seqn; + + unsigned int n = atomic_load(&pool->n_jobs); + unsigned int i = atomic_fetch_add(&pool->next_job, 1); + if (i >= n) { + // Spurios wakeup + continue; + } + + pool->job[i].func(n, i, pool->job[i].data); + + atomic_fetch_sub(&pool->n_pending, 1); + } + + FARF(HIGH, "worker-pool: thread %u stopped", me->id); +} + +AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, uint32_t n_threads, uint32_t stack_size) { + int err = 0; + + if (NULL == context) { + FARF(ERROR, "NULL context passed to worker_pool_init()."); + return AEE_EBADPARM; + } + + // Allocations + int size = (stack_size * n_threads) + (sizeof(worker_pool_t)); + + unsigned char * mem_blob = (unsigned char *) malloc(size); + if (!mem_blob) { + FARF(ERROR, "Could not allocate memory for worker pool!!"); + return AEE_ENOMEMORY; + } + + worker_pool_t * me = (worker_pool_t *) (mem_blob + stack_size * n_threads); + + // name for the first worker, useful in debugging threads + char name[19]; + snprintf(name, 12, "0x%8x:", (int) me); + strcat(name, "worker0"); + me->n_threads = n_threads; + + // initializations + for (unsigned int i = 0; i < me->n_threads; i++) { + me->stack[i] = NULL; + me->thread[i] = 0; + + me->context[i].id = i; + me->context[i].pool = me; + } + + // initialize job queue + me->n_pending = 0; + me->n_jobs = 0; + me->next_job = 0; + me->seqn = 0; + me->killed = 0; + + // launch the workers + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + + for (unsigned int i = 0; i < me->n_threads; i++) { + // set up stack + me->stack[i] = mem_blob; + mem_blob += stack_size; + qurt_thread_attr_set_stack_addr(&attr, me->stack[i]); + qurt_thread_attr_set_stack_size(&attr, stack_size); + + // set up name + qurt_thread_attr_set_name(&attr, name); + name[17] = (name[17] + 1); + // name threads context:worker0, context:worker1, .. (recycle at 9, but num threads should be less than that anyway) + if (name[17] > '9') { + name[17] = '0'; + } + + // set up priority - by default, match the creating thread's prio + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + + if (prio < 1) { + prio = 1; + } + if (prio > LOWEST_USABLE_QURT_PRIO) { + prio = LOWEST_USABLE_QURT_PRIO; + } + + qurt_thread_attr_set_priority(&attr, prio); + + // launch + err = qurt_thread_create(&me->thread[i], &attr, worker_pool_main, (void *) &me->context[i]); + if (err) { + FARF(ERROR, "Could not launch worker threads!"); + worker_pool_release((worker_pool_context_t *) &me); + return AEE_EQURTTHREADCREATE; + } + } + *context = (worker_pool_context_t *) me; + return AEE_SUCCESS; +} + +AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads) { + return worker_pool_init_with_stack_size(context, n_threads, WORKER_THREAD_STACK_SZ); +} + +// clean up worker pool +void worker_pool_release(worker_pool_context_t * context) { + worker_pool_t * me = (worker_pool_t *) *context; + + // if no worker pool exists, return error. + if (NULL == me) { + return; + } + + atomic_store(&me->killed, 1); + atomic_fetch_add(&me->seqn, 1); + qurt_futex_wake(&me->seqn, me->n_threads); + + // de-initializations + for (unsigned int i = 0; i < me->n_threads; i++) { + if (me->thread[i]) { + int status; + (void) qurt_thread_join(me->thread[i], &status); + } + } + + // free allocated memory (were allocated as a single buffer starting at stack[0]) + if (me->stack[0]) { + free(me->stack[0]); + } + + *context = NULL; +} + +// run jobs +AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n) { + worker_pool_t * me = (worker_pool_t *) context; + if (NULL == me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + } + + if (n > me->n_threads) { + FARF(ERROR, "worker-pool: invalid number of jobs %u for n-threads %u", n, me->n_threads); + return AEE_EBADPARM; + } + + memcpy(me->job, job, sizeof(worker_pool_job_t) * n); + + if (n > 1) { + atomic_store(&me->next_job, 1); + atomic_store(&me->n_jobs, n); + atomic_store(&me->n_pending, n - 1); + + // wake up workers + atomic_fetch_add(&me->seqn, 1); + qurt_futex_wake(&me->seqn, n - 1); + } + + // main thread runs job #0 + me->job[0].func(n, 0, me->job[0].data); + + if (n > 1) { + while (atomic_load(&me->n_pending)) + ; + } + + return 0; +} + +// run func +AEEResult worker_pool_run_func(worker_pool_context_t context, worker_callback_t func, void * data, unsigned int n) { + worker_pool_job_t job[n]; + + for (unsigned int i = 0; i < n; i++) { + job[i].func = func; + job[i].data = data; + } + + return worker_pool_run_jobs(context, job, n); +} + +AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio) { + worker_pool_t * me = (worker_pool_t *) context; + + // if no worker pool exists, return error. + if (!me) { + return AEE_ENOMORE; + } + + int result = AEE_SUCCESS; + if (prio < 1) { + prio = 1; + } + if (prio > LOWEST_USABLE_QURT_PRIO) { + prio = LOWEST_USABLE_QURT_PRIO; + } + + for (unsigned int i = 0; i < me->n_threads; i++) { + int res = qurt_thread_set_priority(me->thread[i], (unsigned short) prio); + if (0 != res) { + result = AEE_EBADPARM; + FARF(ERROR, "QURT failed to set priority of thread %d, ERROR = %d", me->thread[i], res); + } + } + + return result; +} + +AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids) { + worker_pool_t * me = (worker_pool_t *) context; + if (!me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + ; + } + + for (int i = 0; i < me->n_threads; i++) { + tids[i] = me->thread[i]; + } + + return AEE_SUCCESS; +} + +AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio) { + worker_pool_t * me = (worker_pool_t *) context; + if (!me) { + FARF(ERROR, "worker-pool: invalid context"); + return AEE_EBADPARM; + } + + int priority = qurt_thread_get_priority(me->thread[0]); + if (priority > 0) { + *prio = priority; + return 0; + } else { + *prio = 0; + return AEE_EBADSTATE; + } +} diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.h b/ggml/src/ggml-hexagon/htp/worker-pool.h new file mode 100644 index 0000000000..6f8c9056c4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/worker-pool.h @@ -0,0 +1,57 @@ +#ifndef HTP_WORKER_POOL_H +#define HTP_WORKER_POOL_H + +// MACRO enables function to be visible in shared-library case. +#define WORKERPOOL_API __attribute__((visibility("default"))) + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// signature of callbacks to be invoked by worker threads +typedef void (*worker_callback_t)(unsigned int n, unsigned int i, void *); + +/// Typedef of worker_pool context +typedef void * worker_pool_context_t; + +/// descriptor for requested callback +typedef struct { + worker_callback_t func; + void * data; +} worker_pool_job_t; + +/// Maximum supported number of worker threads. +#define MAX_NUM_WORKERS 10 + +// Initialize worker pool. +WORKERPOOL_API AEEResult worker_pool_init(worker_pool_context_t * context, uint32_t n_threads); + +// Initialize worker pool with custom stack size +WORKERPOOL_API AEEResult worker_pool_init_with_stack_size(worker_pool_context_t * context, + uint32_t n_threads, + uint32_t stack_size); + +// Kill worker threads and release worker pool resources +WORKERPOOL_API void worker_pool_release(worker_pool_context_t * context); + +// Run jobs with the worker pool. +WORKERPOOL_API AEEResult worker_pool_run_jobs(worker_pool_context_t context, worker_pool_job_t * job, unsigned int n); + +WORKERPOOL_API AEEResult worker_pool_run_func(worker_pool_context_t context, + worker_callback_t func, + void * data, + unsigned int n); + +WORKERPOOL_API AEEResult worker_pool_set_thread_priority(worker_pool_context_t context, unsigned int prio); +WORKERPOOL_API AEEResult worker_pool_get_thread_priority(worker_pool_context_t context, unsigned int * prio); +WORKERPOOL_API AEEResult worker_pool_retrieve_thread_id(worker_pool_context_t context, unsigned int * tids); + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef HTP_WORKER_POOL_H diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 6b499320e7..23b6889919 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -29,10 +29,11 @@ if (CXX_IS_HIPCC) endif() else() # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES. + if(AMDGPU_TARGETS AND NOT GPU_TARGETS) + set(GPU_TARGETS ${AMDGPU_TARGETS}) + endif() if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS}) - elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS}) endif() cmake_minimum_required(VERSION 3.21) enable_language(HIP) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index e9201cdc68..ec37a25337 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, #endif #ifdef __cplusplus +#include #include #include @@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size()); } +// Return true if the edges in the graph match expectations. +inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, + int start_idx, + std::initializer_list> edges) { + for (const auto & edge : edges) { + int dst_node = edge[0]; + int src_idx = edge[1]; + int src_node = edge[2]; + if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) { + return false; + } + } + return true; +} + // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index db33a4ab6c..93a3600b63 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -6156,8 +6156,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); } else if (mode == GGML_SCALE_MODE_BILINEAR) { if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(ne0 - 1) / (ne00 - 1); - sf1 = (float)(ne1 - 1) / (ne01 - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; pixel_offset = 0.0f; } diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b1575b8145..75657f3fca 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -32,8 +32,10 @@ #include "pad.hpp" #include "quantize.hpp" #include "quants.hpp" +#include "roll.hpp" #include "rope.hpp" #include "set_rows.hpp" +#include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" #include "wkv.hpp" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 33f9035075..c97c589943 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -30,6 +30,9 @@ #include #include +#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC +# include +#endif #include #include "ggml-sycl.h" @@ -39,13 +42,16 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/norm.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/set.hpp" #include "ggml-sycl/sycl_hw.hpp" #include "ggml-sycl/getrows.hpp" +#include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/ssm_conv.hpp" #include "ggml.h" static bool g_sycl_loaded = false; @@ -54,6 +60,7 @@ int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; +int g_ggml_sycl_use_async_mem_op = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -237,7 +244,20 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - + // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be + // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in + // other places. +#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC + g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + if (g_ggml_sycl_use_async_mem_op) { + for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { + if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { + g_ggml_sycl_use_async_mem_op = 0; + break; + } + } + } +#endif if (CHECK_TRY_ERROR(g_all_sycl_device_count = dpct::dev_mgr::instance().device_count()) != 0) { initialized = true; @@ -2598,6 +2618,10 @@ catch (sycl::exception const &exc) { std::exit(1); } +static void ggml_sycl_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_repeat_back(ctx, dst); +} static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); @@ -2614,6 +2638,11 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds ggml_sycl_op_rms_norm(ctx, dst); } +static void ggml_sycl_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_rms_norm_back(ctx, dst); +} + static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_l2_norm(ctx, dst); @@ -3031,19 +3060,51 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { } } +// Helper functions to unify device memory allocation for both async and sync paths +static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) { + bool use_async = g_ggml_sycl_use_async_mem_op; +#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC + if (use_async) { + return syclex::async_malloc(*stream, sycl::usm::alloc::device, size); + } +#else + // If async allocation extension is not available, use_async should always be false. + GGML_ASSERT(!use_async); +#endif + return sycl::malloc(size, *stream, sycl::usm::alloc::device); +} + +static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { + bool use_async = g_ggml_sycl_use_async_mem_op; +#if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC + if (use_async) { + syclex::async_free(*stream, ptr); + return; + } +#else + // If async allocation extension is not available, use_async should always be false. + GGML_ASSERT(!use_async); +#endif + sycl::free(ptr, *stream); +} + static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - auto * tmp_buf = sycl::malloc_shared(size, *stream); - SYCL_CHECK( - CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) - .wait())); + uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + GGML_ASSERT((size % sizeof(block_q4_0) == 0)); GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); int offset_blks = offset / sizeof(block_q4_0); auto qs_ptr = data_device + offset_blks * QK4_0 / 2; auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; - stream->parallel_for( + auto reorder_event = stream->parallel_for( size / sizeof(block_q4_0), [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { const block_q4_0* x = (const block_q4_0*)tmp_buf; @@ -3054,9 +3115,11 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; } *(d_ptr + ib) = x[ib].d; - }).wait_and_throw(); - - sycl::free(tmp_buf, *stream); + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + sycl_ext_free(stream, tmp_buf); } static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { @@ -3065,14 +3128,19 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d const int nblocks = size / sizeof(block_q4_K); - auto * tmp_buf = sycl::malloc_shared(size, *stream); - SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } auto * qs_ptr = data_device; auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); - stream->parallel_for(nblocks, [=](auto i) { + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { const block_q4_K * x = (const block_q4_K *) tmp_buf; const int ib = i; @@ -3085,9 +3153,11 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d } dm_ptr[ib] = x[ib].dm; - }).wait_and_throw(); - - sycl::free(tmp_buf, *stream); + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + sycl_ext_free(stream, tmp_buf); } static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { @@ -3096,42 +3166,46 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d const int nblocks = size / sizeof(block_q6_K); - auto * tmp_buf = sycl::malloc_shared(size, *stream); - SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } auto * ql_ptr = data_device; auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks; auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks; sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks); - stream - ->parallel_for(nblocks, - [=](auto i) { - const block_q6_K * x = (const block_q6_K *) tmp_buf; - const int ib = i; + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q6_K * x = (const block_q6_K *) tmp_buf; + const int ib = i; - const uint8_t * ql = x[ib].ql; - const uint8_t * qh = x[ib].qh; - uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; - uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; - uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; + const uint8_t * ql = x[ib].ql; + const uint8_t * qh = x[ib].qh; + uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; + uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; + uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; - for (int j = 0; j < QK_K / 2; ++j) { - base_ql_ptr[j] = ql[j]; - } - for (int j = 0; j < QK_K / 4; ++j) { - base_qh_ptr[j] = qh[j]; - } + for (int j = 0; j < QK_K / 2; ++j) { + base_ql_ptr[j] = ql[j]; + } + for (int j = 0; j < QK_K / 4; ++j) { + base_qh_ptr[j] = qh[j]; + } - for (int j = 0; j < QK_K / 16; ++j) { - base_scales_ptr[j] = x[ib].scales[j]; - } + for (int j = 0; j < QK_K / 16; ++j) { + base_scales_ptr[j] = x[ib].scales[j]; + } - dm_ptr[ib] = x[ib].d; - }) - .wait_and_throw(); - - sycl::free(tmp_buf, *stream); + dm_ptr[ib] = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + sycl_ext_free(stream, tmp_buf); } static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { @@ -3617,6 +3691,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_REPEAT: ggml_sycl_repeat(ctx, dst); break; + case GGML_OP_REPEAT_BACK: + ggml_sycl_repeat_back(ctx, dst); + break; case GGML_OP_GET_ROWS: ggml_sycl_get_rows(ctx, dst); break; @@ -3756,6 +3833,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_LEAKY_RELU: ggml_sycl_leaky_relu(ctx, dst); break; + case GGML_OP_RMS_NORM_BACK: + ggml_sycl_rms_norm_back(ctx, dst); + break; case GGML_OP_RMS_NORM: ggml_sycl_rms_norm(ctx, dst); break; @@ -3851,6 +3931,11 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_SSM_CONV: + ggml_sycl_ssm_conv(ctx, dst); + case GGML_OP_ROLL: + ggml_sycl_roll(ctx, dst); + break; case GGML_OP_ARANGE: ggml_sycl_arange(ctx, dst); break; @@ -4056,6 +4141,18 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) { GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__, ggml_op_name(node_op)); return false; + case GGML_OP_MUL_MAT: + // We cannot use graphs with ggml_sycl_mul_mat() when SYCL async memory allocation extensions are not available, + // as SYCL malloc / free and host wait calls are not supported when recording to a graph which are all present + // in reordering. + if (!g_ggml_sycl_use_async_mem_op) { + GGML_LOG_INFO( + "%s: disabling SYCL graphs due to unsupported node type when using a compiler without the " + "oneAPI async memory allocation extension " + "%s\n", + __func__, ggml_op_name(node_op)); + return false; + } } } return true; @@ -4442,6 +4539,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g ggml_type src0_type = op->src[0]->type; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } + case GGML_OP_REPEAT_BACK: + { + ggml_type src0_type = op->src[0]->type; + return src0_type == GGML_TYPE_F32; + } case GGML_OP_DUP: case GGML_OP_ARGMAX: case GGML_OP_NONE: @@ -4478,6 +4580,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + case GGML_OP_RMS_NORM_BACK: + return ((op->src[0]->ne[0] % WARP_SIZE) == 0); case GGML_OP_SCALE: return true; case GGML_OP_CONT: @@ -4512,6 +4616,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: return true; + case GGML_OP_SSM_CONV: + return op->type == GGML_TYPE_F32 && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_ROLL: + return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; default: diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 4ec1416849..823d3a4828 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -480,6 +480,162 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } +void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz + GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float eps = 1e-5f; + std::memcpy(&eps, dst->op_params, sizeof(float)); + if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f; + + const float * g_base = static_cast(dst->src[0]->data); // dz + const float * x_base = static_cast(dst->src[1]->data); // x + float * dx_base = static_cast< float *>(dst->data); + + const int64_t D = dst->ne[0]; + const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3; + const int64_t N = ggml_nrows(dst); + if (D == 0 || N == 0) return; + + const ggml_tensor *G = dst->src[0]; + const ggml_tensor *X = dst->src[1]; + const int ts = (int) ggml_type_size(X->type); + GGML_ASSERT((size_t) X->nb[0] == (size_t) ts); + GGML_ASSERT((size_t) G->nb[0] == (size_t) ts); + GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts); + + const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts; + const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts; + const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts; + + dpct::queue_ptr q = ctx.stream(); + + // work-group size: multiple of WARP_SIZE, capped by device and 256, and not larger than D + const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device]; + auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; }; + int wg_cap = 256; + if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg); + int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min(D, wg_cap), WARP_SIZE), wg_cap)); + + // FP32 path: per-thread compensated accumulation + hierarchical reduction + q->submit([&](sycl::handler &cgh) { + const int nwarps_loc = std::max(1, WG / WARP_SIZE); + // store one partial value per warp (xx and xg) for cross-warp reduction + auto l_xx = sycl::local_accessor(sycl::range<1>(nwarps_loc), cgh); + auto l_xg = sycl::local_accessor(sycl::range<1>(nwarps_loc), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG), + sycl::range<3>(1, 1, WG)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const int row = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t i1 = row % n1; + const int64_t i2 = (row / n1) % n2; + const int64_t i3 = row / (n1 * n2); + + const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1; + const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1; + float *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1; + + // per-thread accumulation (compensated by default) + float sum_xx = 0.f, sum_xg = 0.f; +#ifndef GGML_SYCL_RMS_BACK_FAST + float c_xx = 0.f, c_xg = 0.f; +#endif + for (int64_t col = tid; col < D; col += WG) { + const float xv = x_row[col]; + const float gv = g_row[col]; +#ifdef GGML_SYCL_RMS_BACK_FAST + sum_xx += xv * xv; + sum_xg += xv * gv; +#else + float y1 = xv * xv - c_xx; + float t1 = sum_xx + y1; + c_xx = (t1 - sum_xx) - y1; + sum_xx = t1; + + float y2 = xv * gv - c_xg; + float t2 = sum_xg + y2; + c_xg = (t2 - sum_xg) - y2; + sum_xg = t2; +#endif + } + + // warp-level reduction + sycl::float2 xx = sycl::float2(sum_xx, +#ifndef GGML_SYCL_RMS_BACK_FAST + c_xx +#else + 0.f +#endif + ); + sycl::float2 xg = sycl::float2(sum_xg, +#ifndef GGML_SYCL_RMS_BACK_FAST + c_xg +#else + 0.f +#endif + ); + xx = warp_reduce_sum(xx, item_ct1); + xg = warp_reduce_sum(xg, item_ct1); + + // cross-warp reduction using local memory (single barrier) + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + + sycl::float2 xx_total = xx; + sycl::float2 xg_total = xg; + if (nwarps > 1) { + if (wi_in_sg == 0) { + l_xx[sg_id] = xx; + l_xg[sg_id] = xg; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (sg_id == 0) { + const unsigned wi_u = wi_in_sg; + sycl::float2 xx_first = (wi_u < static_cast(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f); + sycl::float2 xg_first = (wi_u < static_cast(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f); + xx_total = warp_reduce_sum(xx_first, item_ct1); + xg_total = warp_reduce_sum(xg_first, item_ct1); + } else { + // other subgroups keep their local totals; they'll be ignored + xx_total = xx; + xg_total = xg; + } + // ensure all threads see the first-subgroup result via broadcast below + } + + // compute inv_r and coeff once per row and broadcast to the whole work-group + float inv_r = 0.f; + float coeff = 0.f; + if (tid == 0) { + const float sum_xx_f = xx_total.x() + xx_total.y(); + const float sum_xdz_f = xg_total.x() + xg_total.y(); + const float mean_eps = sum_xx_f / (float) D + eps; + const float sum_eps = sum_xx_f + eps * (float) D; + inv_r = sycl::rsqrt(mean_eps); + coeff = -sum_xdz_f / sum_eps; + } + inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r); + coeff = sycl::group_broadcast(item_ct1.get_group(), coeff); + + for (int64_t col = tid; col < D; col += WG) { + d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r; + } + }); + }); + +} + void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp index 612cd67cf9..8cb885eb2e 100644 --- a/ggml/src/ggml-sycl/norm.hpp +++ b/ggml/src/ggml-sycl/norm.hpp @@ -19,6 +19,8 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); +void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context& ctx, ggml_tensor* dst); + void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst); diff --git a/ggml/src/ggml-sycl/repeat_back.cpp b/ggml/src/ggml-sycl/repeat_back.cpp new file mode 100644 index 0000000000..abcd4cee72 --- /dev/null +++ b/ggml/src/ggml-sycl/repeat_back.cpp @@ -0,0 +1,56 @@ +#include "repeat_back.hpp" + +#include "common.hpp" + +void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const float * src0_dd = (const float *) dst->src[0]->data; + float * dst_dd = (float *) dst->data; + + const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3]; + const int64_t ne00 = dst->src[0]->ne[0], ne01 = dst->src[0]->ne[1], ne02 = dst->src[0]->ne[2], + ne03 = dst->src[0]->ne[3]; + + const int nr0 = (int) (ne00 / ne0); + const int nr1 = (int) (ne01 / ne1); + const int nr2 = (int) (ne02 / ne2); + const int nr3 = (int) (ne03 / ne3); + + const size_t total = ne0 * ne1 * ne2 * ne3; + const int BLOCK_SIZE = 256; + const int num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE; + + queue_ptr stream = ctx.stream(); + + stream->parallel_for( + sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)), + [=](sycl::nd_item<1> item_ct1) { + const size_t i = item_ct1.get_global_linear_id(); + if (i >= total) { + return; + } + + const int i0 = i % ne0; + const int i1 = (i / ne0) % ne1; + const int i2 = (i / (ne0 * ne1)) % ne2; + const int i3 = i / (ne0 * ne1 * ne2); + + float acc = 0.0f; + + for (int j3 = 0; j3 < nr3; ++j3) { + for (int j2 = 0; j2 < nr2; ++j2) { + for (int j1 = 0; j1 < nr1; ++j1) { + for (int j0 = 0; j0 < nr0; ++j0) { + acc += src0_dd[(i0 + j0 * ne0) + (i1 + j1 * ne1) * ne00 + (i2 + j2 * ne2) * ne00 * ne01 + + (i3 + j3 * ne3) * ne00 * ne01 * ne02]; + } + } + } + } + + dst_dd[i] = acc; + }); +} diff --git a/ggml/src/ggml-sycl/repeat_back.hpp b/ggml/src/ggml-sycl/repeat_back.hpp new file mode 100644 index 0000000000..17a87f3e15 --- /dev/null +++ b/ggml/src/ggml-sycl/repeat_back.hpp @@ -0,0 +1,8 @@ +#ifndef GGML_SYCL_REPEAT_BACK_HPP +#define GGML_SYCL_REPEAT_BACK_HPP + +#include "common.hpp" + +void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#endif // GGML_SYCL_REPEAT_BACK_HPP diff --git a/ggml/src/ggml-sycl/roll.cpp b/ggml/src/ggml-sycl/roll.cpp new file mode 100644 index 0000000000..1e05181789 --- /dev/null +++ b/ggml/src/ggml-sycl/roll.cpp @@ -0,0 +1,122 @@ +#include "roll.hpp" +#include "common.hpp" + +using namespace sycl; + +static inline int wrap_add(int i, int shift, int n) { + + int s = i + shift; + return (s >= n) ? (s - n) : s; +} + +static void kernel_roll_fused_i0_i1( + queue &q, + const float *src_d, + float *dst_d, + int ne0, int ne1, int ne2, int ne3, + int sh0, int sh1, int sh2, int sh3) +{ + if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return; + + + const int stride1 = ne0; + const int stride2 = ne0 * ne1; + const int stride3 = ne0 * ne1 * ne2; + + + const int shNe0 = (ne0 - sh0) % ne0; + const int shNe1 = (ne1 - sh1) % ne1; + const int shNe2 = (ne2 - sh2) % ne2; + const int shNe3 = (ne3 - sh3) % ne3; + + + const size_t g0 = (size_t) ne3; + const size_t g1 = (size_t) ne2; + const size_t g2 = (size_t) (ne1 * ne0); + + const range<3> global{ g0, g1, g2 }; + + q.submit([&](handler &h) { + h.parallel_for(global, [=](id<3> idx) { + const int i3 = (int) idx[0]; + const int i2 = (int) idx[1]; + + const int fused = (int) idx[2]; + const int i1 = fused / ne0; + const int i0 = fused - i1 * ne0; // fused % ne0 + + + const int idx_dst = i0 + + i1 * stride1 + + i2 * stride2 + + i3 * stride3; + + + const int s0 = wrap_add(i0, shNe0, ne0); + const int s1 = wrap_add(i1, shNe1, ne1); + const int s2 = wrap_add(i2, shNe2, ne2); + const int s3 = wrap_add(i3, shNe3, ne3); + + const int idx_src = s0 + + s1 * stride1 + + s2 * stride2 + + s3 * stride3; + + dst_d[idx_dst] = src_d[idx_src]; + }); + }); +} + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const ggml_tensor *src = dst->src[0]; + GGML_ASSERT(src && src->type == GGML_TYPE_F32); + + const int ne0 = (int) dst->ne[0]; + const int ne1 = (int) dst->ne[1]; + const int ne2 = (int) dst->ne[2]; + const int ne3 = (int) dst->ne[3]; + + const int32_t *params = (const int32_t *) dst->op_params; + int shift0 = params[0]; + int shift1 = params[1]; + int shift2 = params[2]; + int shift3 = params[3]; + + + if ((shift0 | shift1 | shift2 | shift3) == 0) { + const size_t nb = ggml_nbytes(src); + queue *q = ctx.stream(); + SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb))); + return; + } + + auto norm = [](int sh, int n) -> int { + if (n <= 0) return 0; + sh %= n; + if (sh < 0) sh += n; + return sh; + }; + shift0 = norm(shift0, ne0); + shift1 = norm(shift1, ne1); + shift2 = norm(shift2, ne2); + shift3 = norm(shift3, ne3); + + try { + queue *q = ctx.stream(); + + const float *src_d = (const float *) src->data; + float *dst_d = (float *) dst->data; + GGML_ASSERT(src_d && dst_d); + + kernel_roll_fused_i0_i1( + *q, src_d, dst_d, + ne0, ne1, ne2, ne3, + shift0, shift1, shift2, shift3 + ); + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what()); + throw; + } +} diff --git a/ggml/src/ggml-sycl/roll.hpp b/ggml/src/ggml-sycl/roll.hpp new file mode 100644 index 0000000000..97dc03d64b --- /dev/null +++ b/ggml/src/ggml-sycl/roll.hpp @@ -0,0 +1,20 @@ +// +// MIT license +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_ROLL_HPP +#define GGML_SYCL_ROLL_HPP + +#include "common.hpp" + +void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst); + +#endif // GGML_SYCL_ROLL_HPP diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp new file mode 100644 index 0000000000..0dc0f71c9a --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -0,0 +1,127 @@ +#include "ssm_conv.hpp" +#include "common.hpp" + +#include + +using namespace sycl; + +static void kernel_ssm_conv( + queue &q, + const float *src_data, + const float *weights, + float *dst_data, + int d_conv, + int d_inner, + int n_t, + int n_s, + int ncs __attribute__((unused)), + int src_stride_inner, + int src_stride_seq, + int dst_stride_token, + int dst_stride_seq +) { + const size_t total_work = static_cast(d_inner) * static_cast(n_t) * static_cast(n_s); + const size_t work_group_size = 256; + const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size; + + const range<1> global_range(num_work_groups * work_group_size); + const range<1> local_range(work_group_size); + + q.submit([&](handler &h) { + h.parallel_for( + nd_range<1>(global_range, local_range), + [=](nd_item<1> item) { + const size_t idx = item.get_global_id(0); + if (idx >= total_work) { + return; + } + + const int channel = static_cast(idx % d_inner); + const int token = static_cast((idx / d_inner) % n_t); + const int seq = static_cast(idx / (static_cast(d_inner) * static_cast(n_t))); + + const float *s = src_data + + static_cast(seq) * static_cast(src_stride_seq) + + static_cast(channel) * static_cast(src_stride_inner) + + static_cast(token); + + const float *c = weights + static_cast(channel) * static_cast(d_conv); + + float sumf = 0.0f; + for (int i0 = 0; i0 < d_conv; ++i0) { + sumf += s[i0] * c[i0]; + } + + const size_t dst_idx = + static_cast(seq) * static_cast(dst_stride_seq) + + static_cast(token) * static_cast(dst_stride_token) + + static_cast(channel); + + dst_data[dst_idx] = sumf; + } + ); + }); +} + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int d_conv = src1->ne[0]; + const int ncs = src0->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t); + GGML_ASSERT(src0->ne[1] == d_inner); + GGML_ASSERT(src1->ne[1] == d_inner); + + GGML_ASSERT(dst->ne[0] == d_inner); + GGML_ASSERT(dst->ne[1] == n_t); + GGML_ASSERT(dst->ne[2] == n_s); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + + GGML_ASSERT(src0->nb[1] == src0->ne[0] * static_cast(sizeof(float))); + + const int src_stride_inner = ncs; + const int src_stride_seq = ncs * d_inner; + const int dst_stride_token = d_inner; + const int dst_stride_seq = d_inner * n_t; + + try { + queue *q = ctx.stream(); + + const float *src_data = static_cast(src0->data); + const float *weights = static_cast(src1->data); + float *dst_data = static_cast(dst->data); + + GGML_ASSERT(src_data && weights && dst_data); + + kernel_ssm_conv( + *q, + src_data, + weights, + dst_data, + d_conv, + d_inner, + n_t, + n_s, + ncs, + src_stride_inner, + src_stride_seq, + dst_stride_token, + dst_stride_seq + ); + + } catch (const std::exception &e) { + std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what()); + throw; + } +} diff --git a/ggml/src/ggml-sycl/ssm_conv.hpp b/ggml/src/ggml-sycl/ssm_conv.hpp new file mode 100644 index 0000000000..1a8ad05f0c --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_conv.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 21bd052255..50e7922dc6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define GGML_VK_MAX_NODES 8192 -#define MAX_VK_BUFFERS 256 - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ @@ -387,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11; static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); static constexpr uint32_t num_topk_moe_pipelines = 10; -static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE }; -static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, + GGML_OP_RESHAPE }; +static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }; +static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] +//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] +//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ] +//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ] +//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ] +//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] +//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ] +//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ] +//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_early_softmax_norm_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view + { 5, 0, 4 }, // reshape->src[0] == get_rows + { 6, 0, 5 }, // sum_rows->src[0] == reshape + { 7, 0, 6 }, // clamp->src[0] == sum_rows + { 8, 0, 5 }, // div->src[0] == reshape + { 8, 1, 7 }, // div->src[1] == clamp + { 9, 0, 8 }, // reshape->src[0] == div +}; + +// same as early_softmax_norm but ending after the get_rows +static constexpr std::initializer_list> topk_moe_early_softmax_edges { + { 1, 0, 0 }, // reshape->src[0] == softmax + { 2, 0, 0 }, // argsort->src[0] == softmax + { 3, 0, 2 }, // view->src[0] == argsort + { 4, 0, 1 }, // get_rows->src[0] == reshape + { 4, 1, 3 }, // get_rows->src[1] == view +}; + +//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ] +//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ] +//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ] +//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ] +//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ] +//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ] +static constexpr std::initializer_list> topk_moe_late_softmax_edges { + { 1, 0, 0 }, // view->src[0] == argsort + { 2, 1, 1 }, // get_rows->src[1] == view + { 3, 0, 2 }, // reshape->src[0] == get_rows + { 4, 0, 3 }, // soft_max->src[0] == reshape + { 5, 0, 4 }, // reshape->src[0] == soft_max +}; + +enum topk_moe_mode { + TOPK_MOE_EARLY_SOFTMAX, + TOPK_MOE_EARLY_SOFTMAX_NORM, + TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_COUNT, +}; + +static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { + topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : + num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX : + TOPK_MOE_LATE_SOFTMAX; + return mode; +} struct vk_device_struct { std::recursive_mutex mutex; @@ -488,6 +550,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT]; vk_pipeline pipeline_matmul_split_k_reduce; vk_pipeline pipeline_quantize_q8_1; @@ -525,7 +588,7 @@ struct vk_device_struct { vk_pipeline pipeline_add_id_f32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqrt_f32; @@ -606,8 +669,7 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_split_k_reduce; - // [2] is {!norm, norm} - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT]; std::vector all_pipelines; @@ -955,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); struct vk_op_topk_moe_push_constants { uint32_t n_rows; uint32_t n_expert_used; + float clamp_min; + float clamp_max; }; struct vk_op_add_id_push_constants { @@ -1240,6 +1304,7 @@ struct vk_op_upscale_push_constants { uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; }; struct vk_op_sum_rows_push_constants @@ -1311,7 +1376,6 @@ struct ggml_vk_garbage_collector { std::vector tl_semaphores; std::vector semaphores; std::vector events; - std::vector temp_buffers; std::vector contexts; }; @@ -1482,8 +1546,6 @@ struct ggml_backend_vk_context { // and set to true after the buffer contents are consumed. bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; - vk_buffer buffer_pool[MAX_VK_BUFFERS]; - vk_context_ref compute_ctx; vk_context_ref transfer_ctx; @@ -2452,8 +2514,11 @@ static void ggml_vk_load_shaders(vk_device& device) { l_warptile_id, m_warptile_id, s_warptile_id, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, - l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid, + l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int, + l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k; std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, @@ -2516,10 +2581,16 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + // Integer MMQ has a smaller shared memory profile, but heavier register use l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + // K-quants use even more registers, mitigate by setting WMITER to 1 + l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; @@ -2528,10 +2599,18 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; + m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; + s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; + + l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; + m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; + s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; + // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; - m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2916,18 +2995,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _m[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ if (device->mul_mat ## ID ## _s[TYPE]) { \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ // Create 2 variants, {f16,f32} accumulator @@ -2966,11 +3042,19 @@ static void ggml_vk_load_shaders(vk_device& device) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0); } #endif @@ -3000,6 +3084,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + } +#endif } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); @@ -3026,6 +3128,24 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#endif } #undef CREATE_MM2 #undef CREATE_MMQ @@ -3090,6 +3210,12 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, ); } #endif @@ -3149,7 +3275,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } // reusing CREATE_MM from the fp32 path if ((device->coopmat2 || device->coopmat_support) -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && !device->coopmat_bf16_support #endif ) { @@ -3498,7 +3624,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); - ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3623,8 +3748,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + } ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); @@ -3739,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { return nullptr; @@ -5067,6 +5206,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { @@ -5144,71 +5294,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; } -static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { - VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); - VK_LOG_MEMORY("ggml_vk_pool_malloc"); - - int best_i = -1; - size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs - int worst_i = -1; - size_t worst_size = 0; //largest unused buffer seen so far - for (int i = 0; i < MAX_VK_BUFFERS; ++i) { - vk_buffer &b = ctx->buffer_pool[i]; - if (b != nullptr && b->size >= size && b->size < best_size) { - best_i = i; - best_size = b->size; - } - if (b != nullptr && b->size > worst_size) { - worst_i = i; - worst_size = b->size; - } - } - if(best_i != -1) { - //found the smallest buffer that fits our needs - vk_buffer b = ctx->buffer_pool[best_i]; - ctx->buffer_pool[best_i].reset(); - return b; - } - if(worst_i != -1) { - //no buffer that fits our needs, resize largest one to save memory - vk_buffer& b = ctx->buffer_pool[worst_i]; - ggml_vk_destroy_buffer(b); - } - - return ggml_vk_create_buffer_device(ctx->device, size); -} - -static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { - VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); - for (int i = 0; i < MAX_VK_BUFFERS; ++i) { - vk_buffer& b = ctx->buffer_pool[i]; - if (b == nullptr) { - b = buffer; - return; - } - } - std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; - ggml_vk_destroy_buffer(buffer); -} - -// Returns an available temporary buffer that may only be used temporarily, it will be reused -static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { - // Try to find existing temp buffer with enough capacity - for (auto& buffer : ctx->gc.temp_buffers) { - if (buffer->size >= size) { - return buffer; - } - } - - VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); - - // Otherwise create new buffer - vk_buffer buf = ggml_vk_pool_malloc(ctx, size); - ctx->gc.temp_buffers.push_back(buf); - - return buf; -} - static void * ggml_vk_host_malloc(vk_device& device, size_t size) { VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); vk_buffer buf = ggml_vk_create_buffer(device, size, @@ -5709,14 +5794,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device ggml_vk_ensure_sync_staging_buffer(src->device, size); - ggml_vk_ensure_sync_staging_buffer(dst->device, size); // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); - // memcpy to dst staging buffer - memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); // Copy to dst buffer - ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); } } @@ -6937,10 +7019,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat @@ -6950,8 +7041,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); @@ -6964,12 +7055,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); @@ -6984,9 +7076,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; - const uint64_t y_sz_upd = y_sz * ne12 * ne13; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } if ( (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { @@ -6995,7 +7094,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } @@ -7007,6 +7106,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } return; } @@ -7043,6 +7145,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -7074,6 +7179,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -7082,14 +7198,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + // compute ggml_vk_matmul_id( ctx, subctx, pipeline, - { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, @@ -7855,14 +7976,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - int mode = ggml_get_op_params_i32(dst, 0); + ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); switch (mode) { case GGML_SCALE_MODE_NEAREST: return ctx->device->pipeline_upscale_nearest_f32; case GGML_SCALE_MODE_BILINEAR: return ctx->device->pipeline_upscale_bilinear_f32; - case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: - return ctx->device->pipeline_upscale_bilinear_ac_f32; + default: + return nullptr; } } return nullptr; @@ -8028,8 +8149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (ctx->num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; - return ctx->device->pipeline_topk_moe[idx][with_norm]; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -8084,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; } case GGML_OP_ARGSORT: + if (ctx->num_additional_fused_ops) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + GGML_ASSERT(idx < num_topk_moe_pipelines); + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + return ctx->device->pipeline_topk_moe[idx][mode]; + } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); return ctx->device->pipeline_argsort_f32[idx]; @@ -9351,22 +9479,26 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); - float sf0 = (float)dst->ne[0] / src0->ne[0]; - float sf1 = (float)dst->ne[1] / src0->ne[1]; - float sf2 = (float)dst->ne[2] / src0->ne[2]; - float sf3 = (float)dst->ne[3] / src0->ne[3]; + GGML_TENSOR_UNARY_OP_LOCALS + + float sf0 = (float)ne0 / ne00; + float sf1 = (float)ne1 / ne01; + float sf2 = (float)ne2 / ne02; + float sf3 = (float)ne3 / ne03; + float pixel_offset = 0.5f; if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); - sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + pixel_offset = 0.0f; } ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { (uint32_t)ggml_nelements(dst), 0, 0, - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], - (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, - (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], - sf0, sf1, sf2, sf3, + (uint32_t)ne00, (uint32_t)ne01, + (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size, + (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + sf0, sf1, sf2, sf3, pixel_offset }, dryrun); } @@ -9619,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { - bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1; + topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - ggml_tensor * ids = cgraph->nodes[node_idx + 3]; + ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : + (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : + cgraph->nodes[node_idx + 5]; + ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); @@ -9681,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(d_ids != nullptr); } - vk_op_topk_moe_push_constants pc; + vk_op_topk_moe_push_constants pc {}; pc.n_rows = n_rows; pc.n_expert_used = n_expert_used; + if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } GGML_ASSERT(n_expert_used <= n_experts); @@ -11278,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } + +#define ENABLE_SYNC_LOGGING 0 + if (need_sync) { +#if ENABLE_SYNC_LOGGING + std::cerr << "sync" << std::endl; +#endif ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -11296,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } +#if ENABLE_SYNC_LOGGING + if (!dryrun) { + for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + auto *n = cgraph->nodes[node_idx + i]; + std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; + if (n->op == GGML_OP_GLU) { + std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; + } + std::cerr << std::endl; + } + } +#endif switch (node->op) { case GGML_OP_REPEAT: @@ -11474,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ARGSORT: - ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + if (ctx->num_additional_fused_ops) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + } break; case GGML_OP_SUM: @@ -11789,10 +11950,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * // Clean up after graph processing is done static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); - for (auto& buffer : ctx->gc.temp_buffers) { - ggml_vk_pool_free(ctx, buffer); - } - ctx->gc.temp_buffers.clear(); ctx->prealloc_y_last_pipeline_used = {}; ctx->unsynced_nodes_written.clear(); @@ -11835,10 +11992,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); ctx->prealloc_y_last_pipeline_used = nullptr; - for (auto& buffer : ctx->buffer_pool) { - ggml_vk_destroy_buffer(buffer); - } - ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; @@ -12255,31 +12408,28 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st } static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, - int node_idx, bool with_norm) { + int node_idx, topk_moe_mode mode) { - if (with_norm) { - if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe_norm.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) { - return false; - } - } - } else { - if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) { - return false; - } - for (size_t i = 0; i < topk_moe.size(); ++i) { - if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) { - return false; - } - } + const ggml_tensor * softmax; + const ggml_tensor * weights; + + switch (mode) { + case TOPK_MOE_EARLY_SOFTMAX_NORM: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 9]; + break; + case TOPK_MOE_EARLY_SOFTMAX: + softmax = cgraph->nodes[node_idx + 0]; + weights = cgraph->nodes[node_idx + 4]; + break; + case TOPK_MOE_LATE_SOFTMAX: + softmax = cgraph->nodes[node_idx + 4]; + weights = cgraph->nodes[node_idx + 5]; + break; + default: + return false; } - const ggml_tensor * softmax = cgraph->nodes[node_idx + 0]; - const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4]; - const float * op_params = (const float *)softmax->op_params; float scale = op_params[0]; @@ -12304,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc return false; } - // Check that the nodes don't have any unexpected uses - const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1]; - const ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - const ggml_tensor * view = cgraph->nodes[node_idx + 3]; - const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr; - const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr; - const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr; - const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr; - - // softmax is used by reshape and argsort - if (ggml_node_get_use_count(cgraph, node_idx) != 2 || - reshape1->src[0] != softmax || - argsort->src[0] != softmax) { - return false; - } - // reshape is used by get_rows - if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 || - get_rows->src[0] != reshape1) { - return false; - } - // argsort is used by view - if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 || - view->src[0] != argsort) { - return false; - } - // view is written (via argsort), we can skip checking it - - if (with_norm) { - // get_rows is used by reshape - if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 || - reshape5->src[0] != get_rows) { - return false; - } - - // reshape is used by sum_rows and div - if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 || - sum_rows->src[0] != reshape5 || - div->src[0] != reshape5) { - return false; - } - - // sum_rows is used by div - if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 || - div->src[1] != sum_rows) { - return false; - } - - // div/reshape are written - if (reshape8->src[0] != div) { - return false; - } - } - if (!ctx->device->subgroup_arithmetic || !ctx->device->subgroup_shuffle || !ctx->device->subgroup_require_full_support || @@ -12443,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; } } ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); @@ -12544,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = num_adds - 1; } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) { - ctx->num_additional_fused_ops = topk_moe_norm.size() - 1; - } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) { - ctx->num_additional_fused_ops = topk_moe.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && + ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && + ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { + ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; } } @@ -12679,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * while (first_unused < graph->n_nodes) { std::vector current_set; - // Avoid reordering topk_moe_norm - if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) { - bool is_topk_moe_norm = true; - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { - if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) { - is_topk_moe_norm = false; + // Check for fusion patterns and avoid reordering them + auto const &match_pattern = [&](const std::initializer_list &pattern, int start) -> bool { + if (start + (int)pattern.size() <= graph->n_nodes) { + bool is_pattern = true; + for (size_t j = 0; j < pattern.size(); ++j) { + if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) { + is_pattern = false; + } } + return is_pattern; } - if (is_topk_moe_norm) { - for (size_t j = 0; j < topk_moe_norm.size(); ++j) { + return false; + }; + + auto const &keep_pattern = [&](const std::initializer_list &pattern) -> bool { + if (match_pattern(pattern, first_unused)) { + for (size_t j = 0; j < pattern.size(); ++j) { new_order.push_back(graph->nodes[first_unused + j]); used[first_unused + j] = true; } while (first_unused < graph->n_nodes && used[first_unused]) { first_unused++; } - continue; + return true; } + return false; + }; + + if (keep_pattern(topk_moe_early_softmax_norm)) { + continue; } + if (keep_pattern(topk_moe_early_softmax)) { + continue; + } + if (keep_pattern(topk_moe_late_softmax)) { + continue; + } + // First, grab the next unused node. current_set.push_back(first_unused); @@ -12715,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (is_empty(graph->nodes[j])) { continue; } + // Don't pull forward nodes from fusion patterns + if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_early_softmax, j) || + match_pattern(topk_moe_late_softmax, j)) { + continue; + } bool ok = true; for (int c = first_unused; c < j; ++c) { if (!used[c] && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 0d98f5a9d6..09676a623b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -437,7 +437,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { #if defined(DATA_A_MXFP4) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); - return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]) * 0.5; } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { vec2 v0 = dequantize(ib, iqs, a_offset); @@ -488,9 +488,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); const uint scales = data_a[a_offset + ib].scales[scalesi]; - const vec2 d = vec2(data_a[a_offset + ib].d); + const vec2 dm = vec2(data_a[a_offset + ib].dm); - return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + return dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); } vec2 get_dm(uint ib, uint a_offset) { return vec2(1, 0); @@ -529,7 +529,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -567,7 +567,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[a_offset + ib].d); + const vec2 loadd = vec2(data_a[a_offset + ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 67baedf7c6..8ac6482dc9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -120,7 +120,7 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); - const f16vec2 d = bl.block.d; + const f16vec2 dm = bl.block.dm; const uint idx = coordInBlock[1]; const uint scalesi = (idx & 0xF0) >> 4; // 0..15 @@ -131,7 +131,7 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 qs = unpack8(qs)[idx & 1]; const uint scales = bl.block.scales[scalesi]; - float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4); return ret; } @@ -680,7 +680,7 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords uint32_t qs = bl.block.qs[iqs]; qs >>= shift; qs &= 0xF; - float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp index ffba5a77dd..3194ba291f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -26,7 +26,7 @@ void main() { const float d = e8m0_to_fp32(data_a[ib].e); [[unroll]] for (uint l = 0; l < 8; ++l) { - data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); - data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 16] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp index 58dc2e5dfd..dc05a78348 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -24,8 +24,8 @@ void main() { const uint ql_idx = 32 * ip + il; const uint8_t qs = data_a[i].qs[32 * ip + il]; - FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); - FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].dm.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].dm.y); data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp index 8b7be557e9..0f23dc0a34 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -20,8 +20,8 @@ void main() { const uint is = 2 * il; const uint n = 4; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + n * ir; const uint qs_idx = 32*il + n * ir; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp index 6bc04670fc..970469a601 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -19,8 +19,8 @@ void main() { const uint ir = tid % 16; const uint is = 2 * il; - const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].dm.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].dm.y); const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; const uint qs_idx = 32*il + 2 * ir; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 03ed25d3bf..14093c0de5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,9 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); @@ -75,7 +73,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); } - temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + temp[j][n] = fma(dm.x, sum1, fma(-dm.y, sum2, temp[j][n])); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 21d07d2e50..49d91ad591 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; @@ -81,7 +79,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 9e46c89a11..0d61b4966e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,9 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - vec2 d = vec2(data_a[ib0 + i].d); - const FLOAT_TYPE dall = FLOAT_TYPE(d.x); - const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; @@ -113,7 +111,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); - temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + temp[j][n] = fma(dm.x, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dm.y, smin, temp[j][n])); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index a20788c4b5..d260969f07 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -120,81 +120,11 @@ shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[BN]; -uint _ne1; - -#ifdef MUL_MAT_ID_USE_SUBGROUPS -shared uvec4 ballots_sh[NUM_WARPS]; - -void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { - _ne1 = 0; - uint num_elements = p.nei1 * p.nei0; - uint nei0shift = findLSB(p.nei0); - - uint ids[16]; - uint iter = 0; - - for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { - // prefetch up to 16 elements - if (iter == 0) { - [[unroll]] for (uint k = 0; k < 16; ++k) { - uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; - } - } - uint i = j + gl_LocalInvocationIndex; - bool in_range = i < num_elements; - uint ii1; - if (nei0_is_pow2) { - ii1 = i >> nei0shift; - } else { - ii1 = i / p.nei0; - } - uint ii0 = i - ii1 * p.nei0; - uint id = ids[iter++]; - uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - - ballots_sh[gl_SubgroupID] = ballot; - barrier(); - - uint subgroup_base = 0; - uint total = 0; - for (uint k = 0; k < gl_NumSubgroups; ++k) { - if (k == gl_SubgroupID) { - subgroup_base = total; - } - total += subgroupBallotBitCount(ballots_sh[k]); - } - barrier(); - - uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); - if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { - row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); - } - _ne1 += total; - iter &= 15; - if (_ne1 >= (ic + 1) * BN) { - break; - } - } - barrier(); -} -#endif // MUL_MAT_ID_USE_SUBGROUPS -#endif // MUL_MAT_ID - #ifdef COOPMAT shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #endif +#include "mul_mm_id_funcs.glsl" #include "mul_mm_funcs.glsl" void main() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 0ebfbd6462..ee5ded2e8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -134,15 +134,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 128; // 2 values per idx const uint iqs = idx % 128; // 0..127 - const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); const uint scales = data_a[ib].scales[scalesi]; - const vec2 d = vec2(data_a[ib].d); + const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_Q3_K) @@ -179,7 +179,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint is = 2 * n + b; // 0..7 const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -215,7 +215,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[ib].d); + const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); const uint scidx1 = (is < 4) ? is : (is - 4); @@ -468,7 +468,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; - const float d = e8m0_to_fp32(data_a[ib].e); + const float d = e8m0_to_fp32(data_a[ib].e) * 0.5; const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl new file mode 100644 index 0000000000..1d0e84ac94 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -0,0 +1,70 @@ +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index b5d761c0ba..8b238ac4bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -10,10 +10,9 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif -#ifdef COOPMAT -#extension GL_KHR_cooperative_matrix : enable -#extension GL_KHR_memory_scope_semantics : enable +#if defined(MUL_MAT_ID_USE_SUBGROUPS) #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable #endif #ifdef MUL_MAT_ID @@ -24,7 +23,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif #if defined(A_TYPE_PACKED32) layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif @@ -76,40 +78,27 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#ifdef COOPMAT -#define SHMEM_STRIDE (BK / 4 + 4) -#else -#define SHMEM_STRIDE (BK / 4 + 1) +#define MMQ_SHMEM + +#include "mul_mmq_shmem_types.glsl" + +#ifndef BK_STEP +#define BK_STEP 4 #endif -shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; +// Shared memory cache +shared block_a_cache buf_a[BM * BK_STEP]; +shared block_b_cache buf_b[BN * BK_STEP]; +// Register cache +block_a_cache cache_a[WMITER * TM]; +block_b_cache cache_b; -#ifndef COOPMAT -#if QUANT_AUXF == 1 -shared FLOAT_TYPE buf_a_dm[BM]; -#else -shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; -#endif -#endif - -shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; -#ifndef COOPMAT -shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; -#endif - -#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_A (4 * QUANT_R_MMQ) #define LOAD_VEC_B 16 -#ifdef MUL_MAT_ID -shared u16vec2 row_ids[4096]; -#endif // MUL_MAT_ID - #define NUM_WARPS (BLOCK_SIZE / WARP) -#ifdef COOPMAT -shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; -#endif - +#include "mul_mm_id_funcs.glsl" #include "mul_mmq_funcs.glsl" void main() { @@ -139,26 +128,12 @@ void main() { const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; const uint WSUBN = WN / WNITER; - -#ifdef COOPMAT - const uint warp_i = gl_SubgroupID; - - const uint tiw = gl_SubgroupInvocationID; - - const uint cms_per_row = WM / TM; - const uint cms_per_col = WN / TN; - - const uint storestride = WARP / TM; - const uint store_r = tiw % TM; - const uint store_c = tiw / TM; -#else const uint warp_i = gl_LocalInvocationID.x / WARP; const uint tiw = gl_LocalInvocationID.x % WARP; const uint tiwr = tiw % (WSUBM / TM); const uint tiwc = tiw / (WSUBM / TM); -#endif const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); @@ -172,17 +147,27 @@ void main() { const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; #ifdef MUL_MAT_ID - uint _ne1 = 0; - for (uint ii1 = 0; ii1 < p.nei1; ii1++) { - for (uint ii0 = 0; ii0 < p.nei0; ii0++) { +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } +#else + _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { - row_ids[_ne1] = u16vec2(ii0, ii1); + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } _ne1++; } } } barrier(); +#endif // Workgroup has no work if (ic * BN >= _ne1) return; @@ -209,159 +194,70 @@ void main() { uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; #endif -#ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; - coopmat cm_result; - - coopmat factors[cms_per_row * cms_per_col]; - - coopmat sums[cms_per_row * cms_per_col]; - - [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { - sums[i] = coopmat(0.0f); - } -#else - int32_t cache_a_qs[WMITER * TM * BK / 4]; - - int32_t cache_b_qs[TN * BK / 4]; - ACC_TYPE sums[WMITER * TM * WNITER * TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); } -#endif -#if QUANT_AUXF == 1 - FLOAT_TYPE cache_a_dm[WMITER * TM]; -#else - FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; -#endif - - FLOAT_TYPE_VEC2 cache_b_ds[TN]; - - for (uint block = start_k; block < end_k; block += BK) { + for (uint block = start_k; block < end_k; block += BK * BK_STEP) { [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { - const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; - const uint iqs = loadr_a; const uint buf_ib = loadc_a + l; + const uint ib = pos_a_ib + buf_ib * p.stride_a / BK; + const uint iqs = loadr_a; - if (iqs == 0) { -#if QUANT_AUXF == 1 - buf_a_dm[buf_ib] = get_d(ib); -#else - buf_a_dm[buf_ib] = get_dm(ib); -#endif + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs); } -#if QUANT_R == 1 - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); -#else - const i32vec2 vals = repack(ib, iqs); - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; - buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; -#endif } [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { -#ifdef MUL_MAT_ID - const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; - const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; - const uint ib = idx / 8; - const uint iqs = idx & 0x7; -#else - const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - const uint iqs = loadr_b; -#endif - const uint buf_ib = loadc_b + l; - if (iqs == 0) { - buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[buf_ib]; + const uint ib = pos_b_ib + row_idx.y * p.batch_stride_b / BK + (row_idx.x % p.ne11) * p.stride_b / BK; +#else + const uint ib = pos_b_ib + buf_ib * p.stride_b / BK; +#endif + const uint iqs = loadr_b; + + [[unroll]] for (uint k_step = 0; k_step < BK_STEP; k_step++) { + block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs); } - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; - buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; } barrier(); - pos_a_ib += 1; - pos_b_ib += 1; + pos_a_ib += BK_STEP; + pos_b_ib += BK_STEP; -#ifdef COOPMAT - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - const uint ib_a = warp_r * WM + cm_row * TM; + for (uint k_step = 0; k_step < BK_STEP; k_step++) { // Load from shared into cache - coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { - cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; - } - - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const uint ib_b = warp_c * WN + cm_col * TN; - coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - // TODO: only cache values that are actually needed - [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { - cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; - } - - cm_result = coopmat(0); - cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); - } - - coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); - } - } -#else - // Load from shared into cache - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; - } - } - } - - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; - cache_b_ds[cc] = buf_b_ds[ib]; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; - } - } - [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (uint cc = 0; cc < TN; cc++) { - [[unroll]] for (uint cr = 0; cr < TM; cr++) { - const uint cache_a_idx = wsir * TM + cr; - const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - int32_t q_sum = 0; - [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { - q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], - cache_b_qs[cc * (BK / 4) + idx_k]); - } + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint reg_ib = wsir * TM + cr; + const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; - sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); + block_a_to_registers(reg_ib, k_step * BM + buf_ib); + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + block_b_to_registers(ib); + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + + sums[sums_idx] += mmq_dot_product(cache_a_idx); + } } } } } -#endif barrier(); } @@ -373,54 +269,6 @@ void main() { const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif -#ifdef COOPMAT -#ifdef MUL_MAT_ID - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < BN; col += storestride) { - const uint row_i = dc + cm_col * TN + col + store_c; - if (row_i >= _ne1) break; - - const u16vec2 row_idx = row_ids[row_i]; - - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } -#else - const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float - - [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { - const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; - - if (is_aligned && is_in_bounds) { - // Full coopMat is within bounds and stride_d is aligned with 16B - coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); - coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); - } else if (is_in_bounds) { - // Full coopMat is within bounds, but stride_d is not aligned - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { - // Partial coopMat is within bounds - coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - - [[unroll]] for (uint col = 0; col < TN; col += storestride) { - if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { - data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); - } - } - } - } - } -#endif // MUL_MAT_ID -#else [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { @@ -431,19 +279,21 @@ void main() { const uint row_i = dc_warp + cc; if (row_i >= _ne1) break; - const u16vec2 row_idx = row_ids[row_i]; + const u16vec2 row_idx = row_ids[row_i - ic * BN]; #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * WMITER * TM + wsir * TM + cr; #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); + } #else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[sums_idx].x); } #endif // MUL_MAT_ID } } } } -#endif // COOPMAT } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index fe71eb131c..c0c03fedcc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -6,41 +6,89 @@ // Each iqs value maps to a 32-bit integer -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +// 2-byte loads for Q4_0 blocks (18 bytes) +// 4-byte loads for Q4_1 blocks (20 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); +#ifdef DATA_A_Q4_0 + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); return i32vec2( vui & 0x0F0F0F0F, (vui >> 4) & 0x0F0F0F0F); +#else // DATA_A_Q4_1 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +#endif } +#ifdef DATA_A_Q4_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); } -#endif - -#if defined(DATA_A_Q4_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -} - +#else // DATA_A_Q4_1 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif -#if defined(DATA_A_Q5_0) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q4_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +#else // DATA_A_Q4_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + } +#endif +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const i32vec2 qs_a = i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); + q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM + +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +// 2-byte loads for Q5_0 blocks (22 bytes) +// 4-byte loads for Q5_1 blocks (24 bytes) i32vec2 repack(uint ib, uint iqs) { - // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 - const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1]); + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); const uint32_t vui = pack32(quants); - const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); +#ifdef DATA_A_Q5_0 + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); +#else // DATA_A_Q5_1 + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); +#endif const int32_t v0 = int32_t(vui & 0x0F0F0F0F) | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) @@ -50,40 +98,457 @@ i32vec2 repack(uint ib, uint iqs) { return i32vec2(v0, v1); } +#ifdef DATA_A_Q5_0 ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); } -#endif - -#if defined(DATA_A_Q5_1) -i32vec2 repack(uint ib, uint iqs) { - // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - +#else // DATA_A_Q5_1 ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); } #endif +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { +#ifdef DATA_A_Q5_0 + buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + buf_a[buf_ib].qh = pack32(u16vec2(data_a_packed16[ib].qh[0], data_a_packed16[ib].qh[1])); + } +#else // DATA_A_Q5_1 + buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].qh = data_a_packed32[ib].qh; + } +#endif +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].qh = buf_a[buf_ib].qh; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const uint32_t vui = cache_a[ib_a].qs[iqs]; + const int32_t qh = int32_t(cache_a[ib_a].qh >> (4 * iqs)); + const int32_t qs_a0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + const int32_t qs_b0 = cache_b.qs[iqs]; + const int32_t qs_b1 = cache_b.qs[iqs + 4]; + + q_sum += dotPacked4x8EXT(qs_a0, qs_b0); + q_sum += dotPacked4x8EXT(qs_a1, qs_b1); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + #if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) int32_t repack(uint ib, uint iqs) { - // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 - return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], - data_a[ib].qs[iqs * 2 + 1])); + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); } ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { return ACC_TYPE(float(q_sum) * da * dsb.x); } + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], + data_a_packed16[ib].qs[iqs * 2 + 1])); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE(data_a_packed16[ib].d); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + const int32_t qs_b = cache_b.qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, qs_b); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + return i32vec2( quants & 0x0F0F0F0F, + (quants >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * dsb.x * float(q_sum)); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + buf_a[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])); + buf_a[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])); + + if (iqs == 0) { + buf_a[buf_ib].d = FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e) * 0.5); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d = buf_a[buf_ib].d; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide +// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +int32_t repack(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + // Repack 4x4 quants into one int + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); + + if (iqs == 0) { + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + cache_a[reg_ib].scales = buf_a[buf_ib].scales; + + [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + const uint8_t scale = cache_a[ib_a].scales[iqs / 4]; + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 4] >> ((iqs % 4) * 2)) & 0x03030303); + + sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); + } + + return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint hm_idx = iqs * QUANT_R_MMQ; + const uint iqs_k = (ib % 8) * 8 + hm_idx; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // Repack 2x4 quants into one int + // Add the 3rd bit instead of subtracting it to allow packing the quants + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[hm_idx * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + buf_a[buf_ib].qs[iqs] = pack32(u8vec4(vals00.x, vals00.y, vals01.x, vals01.y)) | + (pack32(u8vec4(vals10.x, vals10.y, vals11.x, vals11.y)) << 4); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = i8vec2(unpack8(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + // Subtract 4 from the quants to correct the 3rd bit offset + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = pack32(unpack8(int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F)) - int8_t(4)); + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); +} + +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + + // Repack 2x4 quants into one int +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + + buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 4); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs * QUANT_R_MMQ; + const uint qh_shift = iqs_k / 8; + + buf_a[buf_ib].qs[iqs] = int32_t(((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); +#endif + + + if (iqs == 0) { + // Scale index + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].dm = buf_a[buf_ib].dm; + + [[unroll]] for (uint iqs = 0; iqs < 8 / QUANT_R_MMQ; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { +#if defined(DATA_A_Q4_K) + const int32_t qs_a = int32_t((cache_a[ib_a].qs[iqs / 2] >> ((iqs % 2) * 4)) & 0x0F0F0F0F); +#else // defined(DATA_A_Q5_K) + const int32_t qs_a = cache_a[ib_a].qs[iqs]; +#endif + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + + return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +} +#endif // MMQ_SHMEM +#endif + +#ifdef MMQ_SHMEM +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } + + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +#ifdef MMQ_SHMEM +void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + buf_a[buf_ib].qs[iqs] = pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)); + + if (iqs == 0) { + const uint is = iqs_k / 4; + const i8vec2 scales = unpack8(data_a_packed16[ib_k].scales[is / 2]); + + buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); + } +} + +void block_a_to_registers(const uint reg_ib, const uint buf_ib) { + cache_a[reg_ib].d_scales = buf_a[buf_ib].d_scales; + + [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { + cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs]; + } +} + +ACC_TYPE mmq_dot_product(const uint ib_a) { + float result = 0.0; + int32_t q_sum = 0; + + [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[0]) * float(q_sum); + q_sum = 0; + + [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { + const int32_t qs_a = cache_a[ib_a].qs[iqs]; + + q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); + } + result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); + + return ACC_TYPE(cache_b.ds.x * result); +} +#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) @@ -103,3 +568,10 @@ FLOAT_TYPE_VEC2 get_dm(uint ib) { return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); } #endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl new file mode 100644 index 0000000000..72fec44049 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -0,0 +1,78 @@ +#if defined(DATA_A_Q4_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q4_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_0) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_Q5_1) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[16/4]; + uint32_t qh; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q8_0) +#define QUANT_R_MMQ 1 +// AMD likes 4, Intel likes 1 and Nvidia likes 2 +#define BK_STEP 1 +struct block_a_cache { + int32_t qs[32/4]; + FLOAT_TYPE dm; +}; +#elif defined(DATA_A_MXFP4) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE d; +}; +#elif defined(DATA_A_Q2_K) +#define QUANT_R_MMQ 4 +struct block_a_cache { + uint32_t qs[2]; + u8vec2 scales; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q3_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 d_scales; +}; +#elif defined(DATA_A_Q4_K) +#define QUANT_R_MMQ 2 +struct block_a_cache { + uint32_t qs[4]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q5_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE_VEC2 dm; +}; +#elif defined(DATA_A_Q6_K) +#define QUANT_R_MMQ 1 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE_VEC2 d_scales; +}; +#endif + +struct block_b_cache +{ + int32_t qs[8]; + FLOAT_TYPE_VEC2 ds; +}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp index 12bd174579..8f67be9799 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,9 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif #include "types.glsl" @@ -84,35 +87,47 @@ void main() { } barrier(); - for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % (w >> 1)) + - (D_STATE * (tid / (w >> 1))) + - j * D_STATE * (D_STATE / (w >> 1)); - if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + (w >> 1)]; + [[unroll]] + for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { + [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { + const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); + if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { + stateC[k] += stateC[k + w]; } } barrier(); } - [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) { + [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { const uint idx = (tid % SUBGROUP_SIZE) + D_STATE * (tid / SUBGROUP_SIZE) + j * D_STATE * (D_STATE / SUBGROUP_SIZE); + const uint max_idx = SUBGROUP_SIZE - 1 + + D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + + j * D_STATE * (D_STATE / SUBGROUP_SIZE); - uint lane = tid % SUBGROUP_SIZE; - - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; + if (idx < SPLIT_H * D_STATE || + max_idx < SPLIT_H * D_STATE) { + float sc; +#if USE_SUBGROUP_ADD + sc = stateC[idx]; + sc = subgroupAdd(sc); +#else + [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { + if (idx + offset < SPLIT_H * D_STATE) { + stateC[idx] += stateC[idx + offset]; + } + barrier(); } - barrier(); - } + if (tid % SUBGROUP_SIZE == 0) { + sc = stateC[idx]; + } +#endif - if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = stateC[idx]; + if (tid % SUBGROUP_SIZE == 0) { + const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); + d[y_base_idx + i * stride_y + k] = sc; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 9e56d5f8a3..bc1c278bf4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -11,6 +11,8 @@ layout (push_constant) uniform parameter { uint n_rows; uint n_expert_used; + float clamp_min; + float clamp_max; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; @@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts = 512; layout(constant_id = 2) const bool with_norm = true; +layout(constant_id = 3) const bool late_softmax = false; const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; @@ -25,6 +28,52 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];}; layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +const float INFINITY = 1.0 / 0.0; + +// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. +void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) { + float max_val = -INFINITY; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + max_val = max(max_val, vals[i]); + } + } + + max_val = subgroupMax(max_val); + + float sum = 0.f; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + const float val = exp(vals[i] - max_val); + vals[i] = val; + sum += val; + } else { + vals[i] = 0.f; + } + } + + sum = subgroupAdd(sum); + + const float inv_sum = 1.0f / sum; + + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + const uint idx = lane + i * WARP_SIZE; + const bool is_active = !use_limit || (idx < limit); + if (is_active) { + vals[i] *= inv_sum; + } + } +} + void main() { const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; if (row >= n_rows) { @@ -35,43 +84,16 @@ void main() { const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; - float logits_r[experts_per_thread]; - - const float INFINITY = 1.0 / 0.0; + float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; - logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY; + const uint expert = i + gl_LocalInvocationID.x; + wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } - float max_val = logits_r[0]; - - [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const float val = logits_r[i]; - max_val = max(val, max_val); - } - - max_val = subgroupMax(max_val); - - float wt[experts_per_thread]; - float tmp = 0.f; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - const float val = logits_r[i]; - wt[i] = exp(val - max_val); - tmp += wt[i]; - } - - tmp = subgroupAdd(tmp); - - const float inv_sum = 1.0f / tmp; - - [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - wt[i] = wt[i] * inv_sum; + if (!late_softmax) { + softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); } // at this point, each thread holds a portion of softmax, @@ -82,6 +104,11 @@ void main() { float output_weights[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] = 0.f; + } + for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; uint max_expert = gl_LocalInvocationID.x; @@ -121,6 +148,7 @@ void main() { if (with_norm) { wt_sum = subgroupAdd(wt_sum); + wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; [[unroll]] @@ -129,6 +157,10 @@ void main() { } } + if (late_softmax) { + softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + } + [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 2fa54ce51f..02578c77c4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -66,6 +66,7 @@ struct block_q4_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q4_1 32 @@ -98,6 +99,7 @@ struct block_q4_1_packed32 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 #define A_TYPE_PACKED32 block_q4_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_0 32 @@ -123,6 +125,7 @@ struct block_q5_0_packed16 #define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q5_1 32 @@ -158,6 +161,7 @@ struct block_q5_1_packed32 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 #define A_TYPE_PACKED32 block_q5_1_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_0 32 @@ -186,6 +190,7 @@ struct block_q8_0_packed32 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 #define A_TYPE_PACKED32 block_q8_0_packed32 +#define DATA_A_QUANT_LEGACY #endif #define QUANT_K_Q8_1 32 @@ -226,21 +231,21 @@ struct block_q2_K { uint8_t scales[QUANT_K_Q2_K/16]; uint8_t qs[QUANT_K_Q2_K/4]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed16 { uint16_t scales[QUANT_K_Q2_K/16/2]; uint16_t qs[QUANT_K_Q2_K/4/2]; - f16vec2 d; + f16vec2 dm; }; struct block_q2_K_packed32 { uint32_t scales[QUANT_K_Q2_K/16/4]; uint32_t qs[QUANT_K_Q2_K/4/4]; - f16vec2 d; + f16vec2 dm; }; #if defined(DATA_A_Q2_K) @@ -249,6 +254,8 @@ struct block_q2_K_packed32 #define A_TYPE block_q2_K #define A_TYPE_PACKED16 block_q2_K_packed16 #define A_TYPE_PACKED32 block_q2_K_packed32 +#define SCALES_PER_32 2 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q3_K 256 @@ -274,27 +281,28 @@ struct block_q3_K_packed16 #define QUANT_R 1 #define A_TYPE block_q3_K #define A_TYPE_PACKED16 block_q3_K_packed16 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q4_K 256 struct block_q4_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[3*QUANT_K_Q4_K/64]; uint8_t qs[QUANT_K_Q4_K/2]; }; struct block_q4_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[3*QUANT_K_Q4_K/64/2]; uint16_t qs[QUANT_K_Q4_K/2/2]; }; struct block_q4_K_packed32 { - f16vec2 d; + f16vec2 dm; uint32_t scales[3*QUANT_K_Q4_K/64/4]; uint32_t qs[QUANT_K_Q4_K/2/4]; }; @@ -310,13 +318,14 @@ struct block_q4_K_packed128 #define A_TYPE block_q4_K #define A_TYPE_PACKED16 block_q4_K_packed16 #define A_TYPE_PACKED32 block_q4_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q5_K 256 struct block_q5_K { - f16vec2 d; + f16vec2 dm; uint8_t scales[12]; uint8_t qh[QUANT_K_Q5_K/8]; uint8_t qs[QUANT_K_Q5_K/2]; @@ -324,12 +333,20 @@ struct block_q5_K struct block_q5_K_packed16 { - f16vec2 d; + f16vec2 dm; uint16_t scales[12/2]; uint16_t qh[QUANT_K_Q5_K/8/2]; uint16_t qs[QUANT_K_Q5_K/2/2]; }; +struct block_q5_K_packed32 +{ + f16vec2 dm; + uint32_t scales[12/4]; + uint32_t qh[QUANT_K_Q5_K/8/4]; + uint32_t qs[QUANT_K_Q5_K/2/4]; +}; + struct block_q5_K_packed128 { uvec4 q5k[11]; @@ -340,6 +357,8 @@ struct block_q5_K_packed128 #define QUANT_R 1 #define A_TYPE block_q5_K #define A_TYPE_PACKED16 block_q5_K_packed16 +#define A_TYPE_PACKED32 block_q5_K_packed32 +#define DATA_A_QUANT_K #endif #define QUANT_K_Q6_K 256 @@ -356,7 +375,7 @@ struct block_q6_K_packed16 { uint16_t ql[QUANT_K_Q6_K/2/2]; uint16_t qh[QUANT_K_Q6_K/4/2]; - int8_t scales[QUANT_K_Q6_K/16]; + int16_t scales[QUANT_K_Q6_K/16/2]; float16_t d; }; @@ -365,6 +384,7 @@ struct block_q6_K_packed16 #define QUANT_R 1 #define A_TYPE block_q6_K #define A_TYPE_PACKED16 block_q6_K_packed16 +#define DATA_A_QUANT_K #endif // IQuants @@ -1363,18 +1383,11 @@ struct block_mxfp4 uint8_t qs[QUANT_K_MXFP4/2]; }; -//struct block_mxfp4_packed16 -//{ -// uint8_t e; -// uint16_t qs[QUANT_K_MXFP4/2/2]; -//}; - #if defined(DATA_A_MXFP4) #define QUANT_K QUANT_K_MXFP4 #define QUANT_R QUANT_R_MXFP4 #define QUANT_AUXF 1 #define A_TYPE block_mxfp4 -//#define A_TYPE_PACKED16 block_mxfp4_packed16 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) @@ -1397,12 +1410,12 @@ void init_iq_shmem(uvec3 wgsize) #endif #if defined(DATA_A_MXFP4) -const FLOAT_TYPE kvalues_mxfp4_const[16] = { - FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), - FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +const int8_t kvalues_mxfp4_const[16] = { + int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), + int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), }; -shared FLOAT_TYPE kvalues_mxfp4[16]; +shared int8_t kvalues_mxfp4[16]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 154a2172d8..8670aad32c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -7,6 +7,7 @@ layout (push_constant) uniform parameter uint nb00; uint nb01; uint nb02; uint nb03; uint ne10; uint ne11; uint ne12; uint ne13; float sf0; float sf1; float sf2; float sf3; + float pixel_offset; } p; #include "types.glsl" @@ -19,7 +20,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag #define NEAREST 0 #define BILINEAR 1 -#define ALIGN_CORNERS (1 << 8) layout (constant_id = 0) const uint scale_mode = 0; @@ -52,7 +52,7 @@ float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { const ivec2 ne0 = ivec2(p.ne00, p.ne01); - const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c = (vec2(i10, i11) + p.pixel_offset) / vec2(p.sf0, p.sf1) - p.pixel_offset; const vec2 c0f = floor(c); const vec2 d = c - c0f; const ivec2 c0 = max(ivec2(c0f), 0); @@ -61,16 +61,6 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { return fetch_bilinear(c0, c1, d, i12, i13); } -float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { - const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); - const vec2 c0f = floor(c); - const vec2 d = c - c0f; - const ivec2 c0 = ivec2(c0f); - const ivec2 c1 = c0 + 1; - - return fetch_bilinear(c0, c1, d, i12, i13); -} - void main() { const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; @@ -91,9 +81,6 @@ void main() { case BILINEAR: result = interpolate_bilinear(i10, i11, i12, i13); break; - case BILINEAR | ALIGN_CORNERS: - result = interpolate_bilinear_align_corners(i10, i11, i12, i13); - break; } data_d[p.d_offset + idx] = D_TYPE(result); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 49bf6c764f..03fa016398 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -566,7 +566,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + // Integer dot mmq performs better with f32 accumulators + if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -574,7 +575,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map base_dict = {{"FLOAT_TYPE", "float"}}; + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { @@ -916,7 +917,8 @@ void process_shaders() { string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); - string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}}); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index d690f5650b..c40960e4f7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -3079,6 +3079,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + LIGHTONOCR = "lightonocr" # Items here are (block size, type size) diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 7111557bfd..5c6817109b 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -14,12 +14,12 @@ except ImportError: SentencePieceProcessor = None try: - from mistral_common.tokens.tokenizers.mistral import MistralTokenizer - from mistral_common.tokens.tokenizers.tekken import Tekkenizer - from mistral_common.tokens.tokenizers.utils import ( + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.tekken import Tekkenizer # pyright: ignore[reportMissingImports] + from mistral_common.tokens.tokenizers.utils import ( # pyright: ignore[reportMissingImports] _filter_valid_tokenizer_files, ) - from mistral_common.tokens.tokenizers.sentencepiece import ( + from mistral_common.tokens.tokenizers.sentencepiece import ( # pyright: ignore[reportMissingImports] SentencePieceTokenizer, ) except ImportError: diff --git a/models/templates/llama-cpp-lfm2.jinja b/models/templates/llama-cpp-lfm2.jinja new file mode 100644 index 0000000000..b7921120bc --- /dev/null +++ b/models/templates/llama-cpp-lfm2.jinja @@ -0,0 +1,37 @@ +{{- bos_token -}} +{%- set system_prompt = "" -%} +{%- set ns = namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "List of tools: <|tool_list_start|>[" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "tool" -%} + {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 90c98c3ffe..122b4788d9 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,5 +1,3 @@ -mistral-common>=1.8.3 - -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu diff --git a/scripts/snapdragon/adb/llama-cli.farf b/scripts/snapdragon/adb/llama-cli.farf new file mode 100644 index 0000000000..de84fe89ad --- /dev/null +++ b/scripts/snapdragon/adb/llama-cli.farf @@ -0,0 +1 @@ +0xffff diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh new file mode 100755 index 0000000000..b2e651e749 --- /dev/null +++ b/scripts/snapdragon/adb/run-bench.sh @@ -0,0 +1,40 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +branch=. +[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +model="Llama-3.2-3B-Instruct-Q4_0.gguf" +[ "$M" != "" ] && model="$M" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +verbose="" +[ "$V" != "" ] && verbose="$V" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +set -x + +adb $adbserial shell " \ + cd $basedir; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --batch-size 128 -ngl 99 $@ \ +" diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh new file mode 100755 index 0000000000..ab8d6d49a2 --- /dev/null +++ b/scripts/snapdragon/adb/run-cli.sh @@ -0,0 +1,53 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +cli_opts= + +branch=. +[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +model="Llama-3.2-3B-Instruct-Q4_0.gguf" +[ "$M" != "" ] && model="$M" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +sched= +[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +set -x + +adb $adbserial shell " \ + cd $basedir; ulimit -c unlimited; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev \ + ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ + -ngl 99 --device $device $cli_opts $@ \ +" diff --git a/scripts/snapdragon/adb/run-tool.sh b/scripts/snapdragon/adb/run-tool.sh new file mode 100755 index 0000000000..bfc213e4c5 --- /dev/null +++ b/scripts/snapdragon/adb/run-tool.sh @@ -0,0 +1,51 @@ +#!/bin/sh +# + +# Basedir on device +basedir=/data/local/tmp/llama.cpp + +cli_opts= + +branch=. +[ "$B" != "" ] && branch=$B + +adbserial= +[ "$S" != "" ] && adbserial="-s $S" + +device="HTP0" +[ "$D" != "" ] && device="$D" + +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$V" + +sched= +[ "$SCHED" != "" ] && sched="GGML_SCHED_DEBUG=2" cli_opts="$cli_opts -v" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" + +opmask= +[ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" + +nhvx= +[ "$NHVX" != "" ] && nhvx="GGML_HEXAGON_NHVX=$NHVX" + +ndev= +[ "$NDEV" != "" ] && ndev="GGML_HEXAGON_NDEV=$NDEV" + +hb= +[ "$HB" != "" ] && hb="GGML_HEXAGON_HOSTBUF=$HB" + +set -x + +tool=$1; shift + +adb $adbserial shell " \ + cd $basedir; ulimit -c unlimited; \ + LD_LIBRARY_PATH=$basedir/$branch/lib \ + ADSP_LIBRARY_PATH=$basedir/$branch/lib \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev $hb ./$branch/bin/$tool $@ \ +" diff --git a/scripts/snapdragon/qdc/readme.md b/scripts/snapdragon/qdc/readme.md new file mode 100644 index 0000000000..b92cf243aa --- /dev/null +++ b/scripts/snapdragon/qdc/readme.md @@ -0,0 +1 @@ +This directory includes pytest based scripts for running CI jobs on Qualcomm Device Cloud (QDC). diff --git a/scripts/snapdragon/qdc/requirements.txt b/scripts/snapdragon/qdc/requirements.txt new file mode 100644 index 0000000000..f04bd682ea --- /dev/null +++ b/scripts/snapdragon/qdc/requirements.txt @@ -0,0 +1,25 @@ +Appium-Python-Client==5.2.4 +attrs==25.4.0 +certifi==2025.10.5 +exceptiongroup==1.3.0 +h11==0.16.0 +idna==3.11 +iniconfig==2.1.0 +outcome==1.3.0.post0 +packaging==25.0 +pluggy==1.6.0 +Pygments==2.19.2 +PySocks==1.7.1 +pytest==8.4.2 +pytest-dependency==0.6.0 +selenium==4.36.0 +setuptools==80.9.0 +sniffio==1.3.1 +sortedcontainers==2.4.0 +tomli==2.3.0 +trio==0.31.0 +trio-websocket==0.12.2 +typing_extensions==4.15.0 +urllib3==2.5.0 +websocket-client==1.9.0 +wsproto==1.2.0 diff --git a/scripts/snapdragon/qdc/tests/test_bench.py b/scripts/snapdragon/qdc/tests/test_bench.py new file mode 100644 index 0000000000..651ab5b717 --- /dev/null +++ b/scripts/snapdragon/qdc/tests/test_bench.py @@ -0,0 +1,63 @@ +import pytest +import subprocess +import sys + +tmp_path='/data/local/tmp' +pkg_path=f'{tmp_path}/llama.cpp' +lib_path=f'{pkg_path}/lib' +bin_path=f'{pkg_path}/bin' + +model='../gguf/Llama-3.2-1B-Instruct-Q4_0.gguf' +cli_pref=f'cd {pkg_path} && LD_LIBRARY_PATH={lib_path} ADSP_LIBRARY_PATH={lib_path} {bin_path}' + + +def run_cmd(cmd): + p = subprocess.run(cmd, text = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT) + sys.stdout.write(p.stdout) + assert(p.returncode == 0) + + +@pytest.mark.dependency() +def test_install(): + run_cmd(['adb', 'push', 'llama.cpp', f'{tmp_path}']) + run_cmd(['adb', 'shell', f'chmod 755 {bin_path}/*']) + + +## Basic cli tests +def run_llama_cli(dev, opts): + prompt='what is the most popular cookie in the world?\nPlease provide a very brief bullet point summary.\nBegin your answer with **BEGIN**.' + opts = '--batch-size 128 -n 128 -no-cnv --seed 42 ' + opts + run_cmd(['adb', 'shell', f'{cli_pref}/llama-cli -m {model} --device {dev} -ngl 99 -t 4 {opts} -p "{prompt}"']) + + +@pytest.mark.dependency(depends=['test_install']) +def test_llama_cli_cpu(): + run_llama_cli('none', '-ctk q8_0 -ctv q8_0 -fa on') + + +@pytest.mark.dependency(depends=['test_install']) +def test_llama_cli_gpu(): + run_llama_cli('GPUOpenCL', '-fa on') + + +@pytest.mark.dependency(depends=['test_install']) +def test_llama_cli_npu(): + run_llama_cli('HTP0', '-ctk q8_0 -ctv q8_0 -fa on') + + +## Basic bench tests +def run_llama_bench(dev): + run_cmd(['adb', 'shell', f'{cli_pref}/llama-bench -m {model} --device {dev} -ngl 99 --batch-size 128 -t 4 -p 128 -n 32']) + + +@pytest.mark.dependency(depends=['test_install']) +def test_llama_bench_cpu(): + run_llama_bench('none') + + +def test_llama_bench_gpu(): + run_llama_bench('GPUOpenCL') + + +def test_llama_bench_npu(): + run_llama_bench('HTP0') diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bd348bcad3..f6192a36e0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -268,9 +268,7 @@ llama_context::llama_context( if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); } - } - if (!hparams.vocab_only) { llama_memory_context_ptr mctx; if (memory) { LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); @@ -343,7 +341,14 @@ llama_context::llama_context( { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); + if (pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } } n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 41fa689437..112d195f29 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); @@ -1006,10 +1009,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); - if (arch == LLM_ARCH_BAILINGMOE2) { - weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20); - cb(weights_sum, "ffn_moe_weights_sum_biased", il); - } + // Avoid division by zero, clamp to smallest number representable by F16 + weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY); + cb(weights_sum, "ffn_moe_weights_sum_clamped", il); weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", il); @@ -1091,6 +1093,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( GGML_ABORT("fatal error"); } + //expand here so that we can fuse ffn gate + ggml_build_forward_expand(gf, cur); + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 736693e174..6d5dd6051e 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -37,8 +38,15 @@ llama_kv_cache::llama_kv_cache( const uint32_t n_layer_kv = hparams.n_layer_kv(); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -53,13 +61,12 @@ llama_kv_cache::llama_kv_cache( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); @@ -167,11 +174,8 @@ llama_kv_cache::llama_kv_cache( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for kv cache"); } @@ -179,7 +183,7 @@ llama_kv_cache::llama_kv_cache( LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); ggml_backend_buffer_clear(buf, 0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -203,7 +207,7 @@ void llama_kv_cache::clear(bool data) { } if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -472,8 +476,8 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { std::map llama_kv_cache::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -957,10 +961,14 @@ bool llama_kv_cache::get_has_shift() const { uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; + // pad the n_kv value so that the graph remains constant across batches and can be reused + // note: this also helps some backends with performance (f.ex https://github.com/ggml-org/llama.cpp/pull/16812#issuecomment-3455112220) + const uint32_t n_pad_cur = std::max(n_pad, 256u); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { const auto & cells = v_cells[sinfo.strm[s]]; - result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); + result = std::max(std::min(cells.size(), std::max(n_pad_cur, GGML_PAD(cells.used_max_p1(), n_pad_cur))), result); } return result; @@ -1298,7 +1306,7 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch size_t llama_kv_cache::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } @@ -2010,8 +2018,3 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } - -uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 85f0663d8c..bf7821c07c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -19,8 +19,6 @@ struct llama_context; class llama_kv_cache : public llama_memory_i { public: - static uint32_t get_padding(const llama_cparams & cparams); - struct stream_copy_info { bool empty() const { assert(ssrc.size() == sdst.size()); @@ -217,8 +215,8 @@ private: // this is the SWA type of the cache - not to be confused with the model SWA type const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) // note: this is not part of the KV state and it's only used to speed-up the find_slot() method diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index d67f5a5f47..276e1697d4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -32,8 +33,15 @@ llama_memory_recurrent::llama_memory_recurrent( cells.clear(); cells.resize(mem_size); + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + std::map ctx_map; + // create a context for each buffer type - std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { @@ -48,13 +56,12 @@ llama_memory_recurrent::llama_memory_recurrent( return nullptr; } - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); + ctx_map.emplace(buft, ctx); return ctx; } - return it->second; + return it->second.get(); }; r_l.resize(n_layer); @@ -93,17 +100,14 @@ llama_memory_recurrent::llama_memory_recurrent( } // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + for (auto & [buft, ctx] : ctx_map) { + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx.get(), buft); if (!buf) { throw std::runtime_error("failed to allocate buffer for rs cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); + ctxs_bufs.emplace_back(std::move(ctx), buf); } { @@ -129,7 +133,7 @@ void llama_memory_recurrent::clear(bool data) { used = 0; if (data) { - for (auto & buf : bufs) { + for (auto & [_, buf] : ctxs_bufs) { ggml_backend_buffer_clear(buf.get(), 0); } } @@ -364,8 +368,8 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; - for (const ggml_backend_buffer_ptr & buf_ptr : bufs) { - ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get()); + for (const auto & [_, buf] : ctxs_bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); } return ret; } @@ -662,7 +666,7 @@ bool llama_memory_recurrent::get_can_shift() const { size_t llama_memory_recurrent::total_size() const { size_t size = 0; - for (const auto & buf : bufs) { + for (const auto & [_, buf] : ctxs_bufs) { size += ggml_backend_buffer_get_size(buf.get()); } diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 077c6e3ce9..47f01d7391 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -109,8 +109,8 @@ private: const uint32_t n_seq_max = 1; - std::vector ctxs; - std::vector bufs; + // ggml contexts for the KV cache along with the allocated backend buffers: + std::vector> ctxs_bufs; size_t total_size() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 28fa4c8f53..fcb152c7de 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -15,7 +15,6 @@ #include #include -#include #include #include #include @@ -404,6 +403,19 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device default buffer type buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + // add the device extra buffer type (if any) + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } + } + return buft_list; } @@ -425,7 +437,7 @@ struct llama_model::impl { llama_mlocks mlock_mmaps; // contexts where the model tensors metadata is stored as well ass the corresponding buffers: - std::vector> ctxs_bufs; + std::vector>> ctxs_bufs; buft_list_t cpu_buft_list; std::map gpu_buft_list; @@ -2241,7 +2253,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return ggml_backend_buft_name(lhs) < ggml_backend_buft_name(rhs); + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; } }; std::map ctx_map; @@ -6226,7 +6238,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - ggml_backend_buffer_t buf = nullptr; + std::vector bufs; if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { for (uint32_t idx = 0; idx < ml.files.size(); idx++) { // only the mmap region containing the tensors in the model is mapped to the backend buffer @@ -6239,15 +6251,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { continue; } const size_t max_size = ggml_get_max_tensor_size(ctx); - buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } + bufs.emplace_back(buf); buf_map.emplace(idx, buf); } } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (buf == nullptr) { throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } @@ -6257,11 +6270,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { mlock_buf->init (ggml_backend_buffer_get_base(buf)); mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); } + bufs.emplace_back(buf); for (uint32_t idx = 0; idx < ml.files.size(); idx++) { buf_map.emplace(idx, buf); } } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), buf); + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); for (auto & buf : buf_map) { // indicate that this buffer contains weights @@ -6287,8 +6301,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // print memory requirements per buffer type - for (auto & [_, buf] : pimpl->ctxs_bufs) { - LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + for (auto & [_, bufs] : pimpl->ctxs_bufs) { + for (auto & buf: bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } } // populate tensors_by_name @@ -6340,8 +6357,10 @@ size_t llama_model::n_devices() const { std::map llama_model::memory_breakdown() const { std::map ret; - for (const auto & [_, buf] : pimpl->ctxs_bufs) { - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + for (const auto & [_, bufs] : pimpl->ctxs_bufs) { + for (const auto & buf : bufs) { + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } } return ret; } @@ -6409,6 +6428,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); @@ -6509,8 +6530,6 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); @@ -18129,6 +18148,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1); @@ -19501,6 +19522,7 @@ struct llm_build_smallthinker : public llm_graph_context{ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); + res->t_embd = cur; // lm_head cur = build_lora_mm(model.output, cur); @@ -19796,7 +19818,7 @@ struct llm_build_apertus : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { @@ -19848,17 +19870,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - const auto padding = llama_kv_cache::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ padding, + /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, @@ -19870,23 +19888,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_attn */ std::move(filter_attn), /* filter_recr */ std::move(filter_recr)); } else { - const auto padding = llama_kv_cache::get_padding(cparams); - uint32_t n_ctx_per_stream = cparams.n_ctx; if (!cparams.kv_unified) { n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; - n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); - - cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; - } else { - n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); - - cparams.n_ctx = n_ctx_per_stream; } - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - llama_memory_i::layer_reuse_cb reuse = nullptr; if (arch == LLM_ARCH_GEMMA3N) { @@ -19913,7 +19920,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, - padding, + 1, nullptr, reuse); } else { @@ -19928,7 +19935,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.kv_unified, n_ctx_per_stream, cparams.n_seq_max, - padding, + 1, hparams.n_swa, hparams.swa_type, nullptr, diff --git a/src/llama-model.h b/src/llama-model.h index 682ac831fb..3c5ad94cfe 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -503,9 +503,8 @@ struct llama_model { ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; - // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; + llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 991c625979..aee1730137 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -511,7 +511,7 @@ struct test_result { }; // Printer classes for different output formats -enum class test_status_t { NOT_SUPPORTED, OK, FAIL }; +enum class test_status_t { NOT_SUPPORTED, OK, FAIL, SKIPPED }; struct test_operation_info { std::string op_name; @@ -687,6 +687,8 @@ struct printer { virtual void print_backend_status(const backend_status_info & info) { (void) info; } virtual void print_overall_summary(const overall_summary_info & info) { (void) info; } + + virtual void print_failed_tests(const std::vector & failed_tests) { (void) failed_tests; } }; struct console_printer : public printer { @@ -804,6 +806,17 @@ struct console_printer : public printer { } } + void print_failed_tests(const std::vector & failed_tests) override { + if (failed_tests.empty()) { + return; + } + + printf("\nFailing tests:\n"); + for (const auto & test_name : failed_tests) { + printf(" %s\n", test_name.c_str()); + } + } + private: void print_test_console(const test_result & result) { printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); @@ -1056,6 +1069,8 @@ struct test_case { std::vector sentinels; + std::string current_op_name; + void add_sentinel(ggml_context * ctx) { if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { return; @@ -1127,7 +1142,10 @@ struct test_case { } } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) { + test_status_t eval(ggml_backend_t backend1, + ggml_backend_t backend2, + const char * op_names_filter, + printer * output_printer) { mode = MODE_TEST; ggml_init_params params = { @@ -1144,11 +1162,12 @@ struct test_case { add_sentinel(ctx); ggml_tensor * out = build_graph(ctx); - std::string current_op_name = op_desc(out); + current_op_name = op_desc(out); + if (!matches_filter(out, op_names_filter)) { //printf(" %s: skipping\n", op_desc(out).c_str()); ggml_free(ctx); - return true; + return test_status_t::SKIPPED; } // check if the backends support the ops @@ -1172,7 +1191,7 @@ struct test_case { } ggml_free(ctx); - return true; + return test_status_t::NOT_SUPPORTED; } // post-graph sentinel @@ -1184,7 +1203,7 @@ struct test_case { if (buf == NULL) { printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1)); ggml_free(ctx); - return false; + return test_status_t::FAIL; } // build graph @@ -1289,7 +1308,7 @@ struct test_case { output_printer->print_test_result(result); } - return test_passed; + return test_passed ? test_status_t::OK : test_status_t::FAIL; } bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) { @@ -1306,7 +1325,7 @@ struct test_case { GGML_ASSERT(ctx); ggml_tensor * out = build_graph(ctx.get()); - std::string current_op_name = op_desc(out); + current_op_name = op_desc(out); if (!matches_filter(out, op_names_filter)) { //printf(" %s: skipping\n", op_desc(out).c_str()); return true; @@ -1435,8 +1454,9 @@ struct test_case { ggml_context_ptr ctx(ggml_init(params)); // smart ptr GGML_ASSERT(ctx); - ggml_tensor * out = build_graph(ctx.get()); - std::string current_op_name = op_desc(out); + ggml_tensor * out = build_graph(ctx.get()); + current_op_name = op_desc(out); + if (!matches_filter(out, op_names_filter)) { return true; } @@ -4712,6 +4732,7 @@ struct test_topk_moe: public test_case { out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens); ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens] + weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY); out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens] out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens); } @@ -4721,6 +4742,140 @@ struct test_topk_moe: public test_case { } }; +struct test_mul_mat_vec_fusion : public test_case { + const ggml_type type; + const ggml_glu_op glu_op; + const int64_t m; + const int64_t n; + const int64_t k; + const bool use_id; + const int n_mats; + const int n_used; + const bool b; // broadcast b matrix (only for use_id) + const bool with_bias; + const bool with_gate; + + test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k, + bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true) + : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) { + if (use_id) { + GGML_ASSERT(n_used <= n_mats); + } + } + + std::string vars() override { + return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate); + } + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "MUL_MAT_VEC_FUSION"; + } + + bool run_whole_graph() override { return true; } + + ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) { + ggml_tensor * out = nullptr; + if (with_gate) { + if (glu_op == GGML_GLU_OP_SWIGLU_OAI) { + constexpr float alpha = 1.702f; + constexpr float limit = 7.0f; + out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit); + } else { + out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op); + } + } + return out; + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + if (!use_id) { + std::array ne = {k, m, 1, 1}; + std::array ne0 = {k, n, 1, 1}; + + ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data()); + ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr; + ggml_tensor * up = ggml_new_tensor(ctx, type, 4, ne0.data()); + + ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur); + if (with_bias) { + std::array bias_ne = {ffn_up->ne[0], 1, 1, 1}; + ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); + ffn_up = ggml_add(ctx, ffn_up, up_bias); + } + + ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr; + if (with_bias && with_gate) { + std::array bias_ne = {ffn_gate->ne[0], 1, 1, 1}; + ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data()); + ffn_gate = ggml_add(ctx, ffn_gate, gate_bias); + } + + ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up; + ggml_set_name(out, "out"); + return out; + } else { + ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats); + ggml_tensor * ups = ggml_new_tensor_3d(ctx, type, k, n, n_mats); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m); + + if (n_used != n_mats) { + ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0); + } + + ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m); + ggml_set_name(cur, "cur"); + + ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids); + if (with_bias) { + ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats); + ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids); + } + + ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr; + if (with_bias && with_gate) { + ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats); + ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids); + } + + ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up; + ggml_set_name(out, "out"); + return out; + } + } + + void initialize_tensors(ggml_context * ctx) override { + if (!use_id) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t); + } + } else { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i % n_mats; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } + } + + double max_nmse_err() override { + return 5e-3; + } +}; + // GGML_OP_SUM struct test_sum : public test_case { const ggml_type type; @@ -6407,6 +6562,7 @@ static std::vector> make_test_cases_eval() { add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1}); add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1}); add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1}); //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1}); //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1}); } @@ -6562,6 +6718,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1})); + + // test cases with large batch size + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1})); } } for (ggml_type type_a : other_types) { @@ -6890,6 +7049,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode)); } test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {1, 4, 3, 2}, {2, 8, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); + test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {4, 1, 3, 2}, {1, 1, 3, 2}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS)); test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); @@ -6982,6 +7143,33 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3})); + for (ggml_type type : base_types) { + for (bool with_gate : {false, true}) { + for (bool use_id : {false, true}) { + for (bool b : {false, true}) { + if (!use_id && b) { + continue; + } + for (bool with_bias : {false, true}) { + if (!with_gate && !with_bias) { + continue; + } + for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) { + if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) { + continue; + } + if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) { + continue; + } + test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256, + use_id, 16, 8, b, with_bias, with_gate)); + } + } + } + } + } + } + for (bool with_norm : {false, true}) { test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm)); test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm)); @@ -7194,16 +7382,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } size_t n_ok = 0; + size_t tests_run = 0; + std::vector failed_tests; for (auto & test : test_cases) { - if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) { + test_status_t status = test->eval(backend, backend_cpu, op_names_filter, output_printer); + if (status == test_status_t::SKIPPED || status == test_status_t::NOT_SUPPORTED) { + continue; + } + tests_run++; + if (status == test_status_t::OK) { n_ok++; + } else if (status == test_status_t::FAIL) { + failed_tests.push_back(test->current_op_name + "(" + test->vars() + ")"); } } - output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false)); + output_printer->print_summary(test_summary_info(n_ok, tests_run, false)); + output_printer->print_failed_tests(failed_tests); ggml_backend_free(backend_cpu); - return n_ok == test_cases.size(); + return n_ok == tests_run; } if (mode == MODE_GRAD) { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 52e23b5ac6..4a8ba849b3 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -16,6 +16,7 @@ #include #include +#include #include using json = nlohmann::ordered_json; @@ -2138,6 +2139,154 @@ static void test_template_output_parsers() { assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); } + { + // LFM2 format tests + auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja"); + std::vector end_tokens{ "<|im_end|>" }; + + auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs { + common_chat_templates_inputs inputs; + inputs.messages = { + std::invoke([&]() -> common_chat_msg { + common_chat_msg msg; + msg.role = "system"; + msg.content = "force json schema.\n"; + return msg; + }), + message_user, + }; + inputs.tools = {special_function_tool}; + return inputs; + }); + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools); + assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); + assert_equals(false, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(true, params.grammar.empty()); + } + + { + auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema); + assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format); + assert_equals(true, params.grammar_lazy); + assert_equals(std::string(R"(<|im_start|>system +List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> +<|im_start|>user +Hey there!<|im_end|> +<|im_start|>assistant +)"), params.prompt); + assert_equals(false, params.grammar.empty()); + } + + // Test parsing regular content + assert_msg_equals(message_assist, + common_chat_parse( + "Hello, world!\nWhat's up?", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test single tool call with JSON format + common_chat_msg msg_single_tool_call; + msg_single_tool_call.role = "assistant"; + msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); + assert_msg_equals( + msg_single_tool_call, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with string argument + common_chat_msg msg_tool_call_string; + msg_tool_call_string.role = "assistant"; + msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_string, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with multiple arguments + common_chat_msg msg_multi_args; + msg_multi_args.role = "assistant"; + msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); + assert_msg_equals( + msg_multi_args, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test multiple tool calls in single array + common_chat_msg msg_multiple_tools; + msg_multiple_tools.role = "assistant"; + msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); + assert_msg_equals( + msg_multiple_tools, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content before + common_chat_msg msg_content_before_tool; + msg_content_before_tool.role = "assistant"; + msg_content_before_tool.content = "Let me check the weather for you."; + msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_before_tool, + common_chat_parse( + "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with content after + common_chat_msg msg_content_after_tool; + msg_content_after_tool.role = "assistant"; + msg_content_after_tool.content = "Here's the result."; + msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_content_after_tool, + common_chat_parse( + "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Test tool call with newlines (common in LLM output) + common_chat_msg msg_tool_call_newlines; + msg_tool_call_newlines.role = "assistant"; + msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); + assert_msg_equals( + msg_tool_call_newlines, + common_chat_parse( + "<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); + + // Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}] + // Unlike other formats, LFM2 template does not render tool calls in conversation history, + // so we don't use test_templates() for tool call generation. Instead, the parsing tests + // above verify edge cases and format variations for the tool call output format. + } } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 67df240c6f..8a55bc54ae 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1124,9 +1124,9 @@ static void test_all(const std::string & lang, std::function` option, each test can be run at a specified context depth, pr For a description of the other options, see the [main example](../main/README.md). +> [!NOTE] +> The measurements with `llama-bench` do not include the times for tokenization and for sampling. + ## Examples ### Text generation with different models @@ -131,7 +134,7 @@ $ ./llama-bench -n 0 -n 16 -p 64 -t 1,2,4,8,16,32 | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | pp 64 | 33.52 ± 0.03 | | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 16 | tg 16 | 15.32 ± 0.05 | | llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | pp 64 | 59.00 ± 1.11 | -| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 || +| llama 7B mostly Q4_0 | 3.56 GiB | 6.74 B | CPU | 32 | tg 16 | 16.41 ± 0.79 | ### Different numbers of layers offloaded to the GPU diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1669fad99b..ad2108d179 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -139,6 +139,7 @@ enum projector_type { PROJECTOR_TYPE_VOXTRAL, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, + PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_UNKNOWN, }; @@ -161,6 +162,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, + { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index f2abf88523..b44f0a3a28 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -171,7 +171,7 @@ struct clip_hparams { int32_t n_head; int32_t n_layer; // idefics3 - int32_t preproc_image_size = 0; + int32_t preproc_image_size = 0; // aka max_dimension int32_t proj_scale_factor = 0; float image_mean[3]; @@ -621,7 +621,7 @@ struct clip_graph { } // arrangement of the [IMG_BREAK] token - { + if (model.token_embd_img_break) { // not efficient, but works // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows] // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension @@ -2095,6 +2095,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 res = graph.build_siglip(); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { res = graph.build_pixtral(); } break; @@ -2380,6 +2381,7 @@ struct clip_model_loader { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { hparams.rope_theta = 10000.0f; hparams.warmup_image_size = hparams.patch_size * 8; @@ -2722,6 +2724,15 @@ struct clip_model_loader { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_LIGHTONOCR: + { + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); + } break; case PROJECTOR_TYPE_ULTRAVOX: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3210,8 +3221,8 @@ struct image_manipulation { return {0, 0}; } - float scale = std::min(1.0f, std::min(static_cast(max_dimension) / inp_size.width, - static_cast(max_dimension) / inp_size.height)); + float scale = std::min(static_cast(max_dimension) / inp_size.width, + static_cast(max_dimension) / inp_size.height); float target_width_f = static_cast(inp_size.width) * scale; float target_height_f = static_cast(inp_size.height) * scale; @@ -3374,7 +3385,7 @@ struct llava_uhd { // resize to overview size clip_image_u8_ptr resized_img(clip_image_u8_init()); - image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height); + image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size); output.push_back(std::move(resized_img)); if (inst.slices.empty()) { // no slices, just return the resized image @@ -3576,6 +3587,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( original_size, params.image_size, params.preproc_image_size); + // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n", + // __func__, original_size.width, original_size.height, + // refined_size.width, refined_size.height); llava_uhd::slice_instructions instructions; instructions.overview_size = clip_image_size{params.image_size, params.image_size}; @@ -3586,6 +3600,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str }; for (int y = 0; y < refined_size.height; y += params.image_size) { for (int x = 0; x < refined_size.width; x += params.image_size) { + // LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y); instructions.slices.push_back(llava_uhd::slice_coordinates{ /* x */x, /* y */y, @@ -3622,7 +3637,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); return true; - } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) { + } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL + || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR + ) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -3865,12 +3882,17 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im n_patches = x_patch * y_patch; } break; case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: { // dynamic size int n_merge = params.spatial_merge_size; int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1); int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1); - n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + if (ctx->model.token_embd_img_break) { + n_patches = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } else { + n_patches = n_patches_y * n_patches_x; + } } break; case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: @@ -4247,6 +4269,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions int n_patches_per_col = image_size_width / patch_size; @@ -4377,6 +4400,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_model_peg_0_b->ne[0]; case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_PIXTRAL: + case PROJECTOR_TYPE_LIGHTONOCR: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_MLP_NORM: return ctx->model.mm_3_b->ne[0]; diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 5fde6ca0c3..fd1fb6581b 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -76,9 +76,11 @@ struct mtmd_cli_context { mtmd::bitmaps bitmaps; - // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another - // so here we don't need to keep track of chat history + // chat template common_chat_templates_ptr tmpls; + std::vector chat_history; + bool use_jinja = false; + // TODO: support for --system-prompt with /clear command // support for legacy templates (models not having EOT token) llama_tokens antiprompt_tokens; @@ -108,6 +110,8 @@ struct mtmd_cli_context { } tmpls = common_chat_templates_init(model, params.chat_template); + use_jinja = params.use_jinja; + chat_history.clear(); LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str()); init_vision_context(params); @@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) { return 1; } } + + std::string generated_text = common_detokenize(ctx.lctx, generated_tokens); + common_chat_msg msg; + msg.role = "assistant"; + msg.content = generated_text; + ctx.chat_history.push_back(std::move(msg)); + return 0; } -static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) { - common_chat_templates_inputs tmpl_inputs; - tmpl_inputs.messages = {msg}; - tmpl_inputs.add_generation_prompt = true; - tmpl_inputs.use_jinja = false; // jinja is buggy here - auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs); - LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); +static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) { + LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n", + new_msg.role.c_str(), new_msg.content.c_str()); + auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history, + new_msg, new_msg.role == "user", + ctx.use_jinja); + ctx.chat_history.push_back(new_msg); + return formatted; +} + +static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { + bool add_bos = ctx.chat_history.empty(); + auto formatted_chat = chat_add_and_format(ctx, msg); + LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); mtmd_input_text text; - text.text = formatted_chat.prompt.c_str(); + text.text = formatted_chat.c_str(); text.add_special = add_bos; text.parse_special = true; @@ -303,7 +321,7 @@ int main(int argc, char ** argv) { return 1; // error is already printed by libmtmd } } - if (eval_message(ctx, msg, true)) { + if (eval_message(ctx, msg)) { return 1; } if (!g_is_interrupted && generate_response(ctx, n_predict)) { @@ -322,7 +340,6 @@ int main(int argc, char ** argv) { LOG("\n /quit or /exit exit the program"); LOG("\n"); - bool is_first_msg = true; std::string content; while (!g_is_interrupted) { @@ -342,7 +359,8 @@ int main(int argc, char ** argv) { } if (line == "/clear") { ctx.n_past = 0; - llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS + ctx.chat_history.clear(); + llama_memory_clear(llama_get_memory(ctx.lctx), true); LOG("Chat history cleared\n\n"); continue; } @@ -367,7 +385,7 @@ int main(int argc, char ** argv) { common_chat_msg msg; msg.role = "user"; msg.content = content; - int ret = eval_message(ctx, msg, is_first_msg); + int ret = eval_message(ctx, msg); if (ret) { return 1; } @@ -376,7 +394,6 @@ int main(int argc, char ** argv) { return 1; } content.clear(); - is_first_msg = false; } } if (g_is_interrupted) LOG("\nInterrupted by user\n"); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 4d487581ae..3b901bfac8 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -275,6 +275,11 @@ struct mtmd_context { img_beg = ""; img_end = ""; + } else if (proj == PROJECTOR_TYPE_LIGHTONOCR) { + // <|im_start|> ... (image embeddings) ... <|im_end|> + img_beg = "<|im_start|>"; + img_end = "<|im_end|>"; + } } diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index dbdf7656a6..c227074636 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -70,6 +70,7 @@ add_test_vision "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" add_test_vision "ggml-org/LFM2-VL-450M-GGUF:Q8_0" add_test_vision "ggml-org/granite-docling-258M-GGUF:Q8_0" +add_test_vision "ggml-org/LightOnOCR-1B-1025-GGUF:Q8_0" add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0" add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M" @@ -138,7 +139,10 @@ for i in "${!arr_hf[@]}"; do echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log - if echo "$output" | grep -iq "new york"; then + # either contains "new york" or both "men" and "walk" + if echo "$output" | grep -iq "new york" \ + || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") + then result="$prefix \033[32mOK\033[0m: $bin $hf" else result="$prefix \033[31mFAIL\033[0m: $bin $hf" diff --git a/tools/run/CMakeLists.txt b/tools/run/CMakeLists.txt index e52294ccc0..6ad7534e29 100644 --- a/tools/run/CMakeLists.txt +++ b/tools/run/CMakeLists.txt @@ -13,5 +13,11 @@ endif () if(LLAMA_TOOLS_INSTALL) install(TARGETS ${TARGET} RUNTIME) endif() + +if (CMAKE_SYSTEM_NAME MATCHES "AIX") + # AIX's flock() function comes from libbsd.a + target_link_libraries(${TARGET} PRIVATE -lbsd) +endif() + target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 7b56d87e43..026b53b286 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 6f09529744..1d9dc5105e 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -345,10 +345,14 @@ export class SchemaConverter { const selectors = ref.split('#')[1].split('/').slice(1); for (const sel of selectors) { - if (!target || !(sel in target)) { + const selIndex = parseInt(sel, 10); + if (target && sel in target) { + target = target[sel]; + } else if (target && selIndex in target) { + target = target[selIndex]; + } else { throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`); } - target = target[sel]; } this._refs[ref] = target; @@ -594,7 +598,8 @@ export class SchemaConverter { } _resolveRef(ref) { - let refName = ref.split('/').pop(); + let refFragment = ref.split('#').pop(); + let refName = 'ref' + refFragment.replace(/[^a-zA-Z0-9-]+/g, '-'); if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) { this._refsBeingResolved.add(ref); const resolved = this._refs[ref]; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 8737fba124..cb794ab647 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2839,7 +2839,7 @@ struct server_context { slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token) { + } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); send_text = stop_pos == std::string::npos; } @@ -2866,10 +2866,12 @@ struct server_context { // if context shifting is disabled, make sure that we don't run out of context if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } // check the limits @@ -2929,16 +2931,6 @@ struct server_context { } } - // if context shift is disabled, we stop when it reaches the context limit - if (slot.n_past >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); - } - if (llama_vocab_is_eog(vocab, result.tok)) { slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; @@ -2946,19 +2938,6 @@ struct server_context { SLT_DBG(slot, "%s", "stopped by EOS\n"); } - const auto n_ctx_train = llama_model_n_ctx_train(model); - - if (slot.task->params.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; // stop prediction - - SLT_WRN(slot, - "n_predict (%d) is set for infinite generation. " - "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n", - slot.task->params.n_predict, n_ctx_train); - } - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); return slot.has_next_token; // continue @@ -5714,6 +5693,7 @@ int main(int argc, char ** argv) { clean_up(); t.join(); + llama_memory_breakdown_print(ctx_server.ctx); return 0; } diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index 4adbbde64f..7b047b7b3b 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -45,7 +45,7 @@ def test_ctx_shift_enabled(): @pytest.mark.parametrize("n_predict,n_token_output,truncated", [ (64, 64, False), - (-1, 120, True), + (-1, 248, True), # 8 tokens prompt + 248 tokens generated = 256 tokens total ]) def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): global server diff --git a/tools/server/webui/src/routes/+page.svelte b/tools/server/webui/src/routes/+page.svelte index 2cd2d5c373..cd18dabccb 100644 --- a/tools/server/webui/src/routes/+page.svelte +++ b/tools/server/webui/src/routes/+page.svelte @@ -2,6 +2,9 @@ import { ChatScreen } from '$lib/components/app'; import { chatStore, isInitialized } from '$lib/stores/chat.svelte'; import { onMount } from 'svelte'; + import { page } from '$app/state'; + + let qParam = $derived(page.url.searchParams.get('q')); onMount(async () => { if (!isInitialized) { @@ -9,6 +12,11 @@ } chatStore.clearActiveConversation(); + + if (qParam !== null) { + await chatStore.createConversation(); + await chatStore.sendMessage(qParam); + } }); diff --git a/vendor/minja/minja.hpp b/vendor/minja/minja.hpp index dad75efbba..57b138add6 100644 --- a/vendor/minja/minja.hpp +++ b/vendor/minja/minja.hpp @@ -55,7 +55,7 @@ inline std::string normalize_newlines(const std::string & s) { } /* Values that behave roughly like in Python. */ -class Value : public std::enable_shared_from_this { +class Value { public: using CallableType = std::function &, ArgumentsValue &)>; using FilterType = std::function &, ArgumentsValue &)>; @@ -158,12 +158,14 @@ public: Value(const json & v) { if (v.is_object()) { auto object = std::make_shared(); + object->reserve(v.size()); for (auto it = v.begin(); it != v.end(); ++it) { - (*object)[it.key()] = it.value(); + object->emplace_back(it.key(), Value(it.value())); } object_ = std::move(object); } else if (v.is_array()) { auto array = std::make_shared(); + array->reserve(v.size()); for (const auto& item : v) { array->push_back(Value(item)); } @@ -610,7 +612,7 @@ static std::string error_location_suffix(const std::string & source, size_t pos) return out.str(); } -class Context : public std::enable_shared_from_this { +class Context { protected: Value values_; std::shared_ptr parent_; @@ -706,7 +708,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall }; static std::string typeToString(Type t) { switch (t) { @@ -729,6 +731,8 @@ public: case Type::EndGeneration: return "endgeneration"; case Type::Break: return "break"; case Type::Continue: return "continue"; + case Type::Call: return "call"; + case Type::EndCall: return "endcall"; } return "Unknown"; } @@ -846,6 +850,17 @@ struct LoopControlTemplateToken : public TemplateToken { LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} }; +struct CallTemplateToken : public TemplateToken { + std::shared_ptr expr; + CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) + : TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {} +}; + +struct EndCallTemplateToken : public TemplateToken { + EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) + : TemplateToken(Type::EndCall, loc, pre, post) {} +}; + class TemplateNode { Location location_; protected: @@ -1047,36 +1062,48 @@ public: } } } - void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { if (!name) throw std::runtime_error("MacroNode.name is null"); if (!body) throw std::runtime_error("MacroNode.body is null"); - auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { - auto call_context = macro_context; + + // Use init-capture to avoid dangling 'this' pointer and circular references + auto callable = Value::callable([weak_context = std::weak_ptr(context), + name = name, params = params, body = body, + named_param_positions = named_param_positions] + (const std::shared_ptr & call_context, ArgumentsValue & args) { + auto context_locked = weak_context.lock(); + if (!context_locked) throw std::runtime_error("Macro context no longer valid"); + auto execution_context = Context::make(Value::object(), context_locked); + + if (call_context->contains("caller")) { + execution_context->set("caller", call_context->get("caller")); + } + std::vector param_set(params.size(), false); for (size_t i = 0, n = args.args.size(); i < n; i++) { auto & arg = args.args[i]; if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); param_set[i] = true; - auto & param_name = params[i].first; - call_context->set(param_name, arg); + const auto & param_name = params[i].first; + execution_context->set(param_name, arg); } for (auto & [arg_name, value] : args.kwargs) { auto it = named_param_positions.find(arg_name); if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - call_context->set(arg_name, value); + execution_context->set(arg_name, value); param_set[it->second] = true; } // Set default values for parameters that were not passed for (size_t i = 0, n = params.size(); i < n; i++) { if (!param_set[i] && params[i].second != nullptr) { - auto val = params[i].second->evaluate(context); - call_context->set(params[i].first, val); + auto val = params[i].second->evaluate(call_context); + execution_context->set(params[i].first, val); } } - return body->render(call_context); + return body->render(execution_context); }); - macro_context->set(name->get_name(), callable); + context->set(name->get_name(), callable); } }; @@ -1611,6 +1638,44 @@ public: } }; +class CallNode : public TemplateNode { + std::shared_ptr expr; + std::shared_ptr body; + +public: + CallNode(const Location & loc, std::shared_ptr && e, std::shared_ptr && b) + : TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("CallNode.expr is null"); + if (!body) throw std::runtime_error("CallNode.body is null"); + + // Use init-capture to avoid dangling 'this' pointer and circular references + auto caller = Value::callable([weak_context = std::weak_ptr(context), body=body] + (const std::shared_ptr &, ArgumentsValue &) -> Value { + auto context_locked = weak_context.lock(); + if (!context_locked) throw std::runtime_error("Caller context no longer valid"); + return Value(body->render(context_locked)); + }); + + context->set("caller", caller); + + auto call_expr = dynamic_cast(expr.get()); + if (!call_expr) { + throw std::runtime_error("Invalid call block syntax - expected function call"); + } + + Value function = call_expr->object->evaluate(context); + if (!function.is_callable()) { + throw std::runtime_error("Call target must be callable: " + function.dump()); + } + ArgumentsValue args = call_expr->args.evaluate(context); + + Value result = function.call(context, args); + out << result.to_str(); + } +}; + class FilterExpr : public Expression { std::vector> parts; public: @@ -2320,7 +2385,7 @@ private: static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); static std::regex block_close_regex(R"(\s*([-~])?%\})"); @@ -2443,6 +2508,15 @@ private: } else if (keyword == "endmacro") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "call") { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (keyword == "endcall") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "filter") { auto filter = parseExpression(); if (!filter) throw std::runtime_error("Expected expression in filter block"); @@ -2575,6 +2649,12 @@ private: throw unterminated(**start); } children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto call_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(call_token->expr), std::move(body))); } else if (auto filter_token = dynamic_cast(token.get())) { auto body = parseTemplate(begin, it, end); if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { @@ -2588,6 +2668,7 @@ private: } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) + || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get())