diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51a3dc76e9..6c7ab71143 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -295,6 +295,7 @@ jobs: -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -307,6 +308,7 @@ jobs: -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DGGML_OPENMP=OFF + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Test diff --git a/.github/workflows/server-metal.yml b/.github/workflows/server-metal.yml new file mode 100644 index 0000000000..1d707bef44 --- /dev/null +++ b/.github/workflows/server-metal.yml @@ -0,0 +1,73 @@ +name: Server-Metal + +on: + workflow_dispatch: # allows manual triggering + inputs: + sha: + description: 'Commit SHA1 to build' + required: false + type: string + slow_tests: + description: 'Run slow tests' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] + +env: + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + LLAMA_LOG_VERBOSITY: 10 + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + server-metal: + runs-on: [self-hosted, macOS, ARM64] + + name: server-metal (${{ matrix.wf_name }}) + strategy: + matrix: + build_type: [Release] + wf_name: ["GPUx1"] + include: + - build_type: Release + extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1" + wf_name: "GPUx1, backend-sampling" + - build_type: Release + extra_args: "GGML_METAL_DEVICES=2" + wf_name: "GPUx2" + - build_type: Release + extra_args: "GGML_METAL_DEVICES=2 LLAMA_ARG_BACKEND_SAMPLING=1" + wf_name: "GPUx2, backend-sampling" + fail-fast: false + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Build + id: cmake_build + run: | + cmake -B build -DGGML_SCHED_NO_REALLOC=ON + cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server + + - name: Tests + id: server_integration_tests + if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }} + run: | + cd tools/server/tests + python3 -m venv venv + source venv/bin/activate + pip install -r requirements.txt + export ${{ matrix.extra_args }} + pytest -v -x -m "not slow" diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml index 6d1b617371..94899c9376 100644 --- a/.github/workflows/server-webui.yml +++ b/.github/workflows/server-webui.yml @@ -8,10 +8,6 @@ on: description: 'Commit SHA1 to build' required: false type: string - slow_tests: - description: 'Run slow tests' - required: true - type: boolean push: branches: - master @@ -101,119 +97,3 @@ jobs: if: ${{ always() && steps.playwright.conclusion == 'success' }} run: npm run test:e2e working-directory: tools/server/webui - - server-build: - runs-on: ubuntu-latest - - strategy: - matrix: - sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken - build_type: [RelWithDebInfo] - include: - - build_type: Release - sanitizer: "" - fail-fast: false # 
While -DLLAMA_SANITIZE_THREAD=ON is broken - - steps: - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get -y install \ - build-essential \ - xxd \ - git \ - cmake \ - curl \ - wget \ - language-pack-en \ - libssl-dev - - - name: Clone - id: checkout - uses: actions/checkout@v6 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Python setup - id: setup_python - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt - - - name: Setup Node.js for WebUI - uses: actions/setup-node@v6 - with: - node-version: "22" - cache: "npm" - cache-dependency-path: "tools/server/webui/package-lock.json" - - - name: Install WebUI dependencies - run: npm ci - working-directory: tools/server/webui - - - name: Build WebUI - run: npm run build - working-directory: tools/server/webui - - - name: Build (no OpenMP) - id: cmake_build_no_openmp - if: ${{ matrix.sanitizer == 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build_sanitizers - if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build - if: ${{ matrix.sanitizer == '' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Tests - id: server_integration_tests - if: ${{ matrix.sanitizer == '' }} - env: - GITHUB_ACTIONS: "true" - run: | - cd tools/server/tests - ./tests.sh - - - name: Tests (sanitizers) - id: server_integration_tests_sanitizers - if: ${{ matrix.sanitizer != '' }} - run: | - cd tools/server/tests - LLAMA_SANITIZE=1 ./tests.sh - - - name: Slow tests - id: server_integration_tests_slow - if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} - run: | - cd tools/server/tests - SLOW_TESTS=1 ./tests.sh diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 3d342c35f7..99d05226ba 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -81,18 +81,14 @@ jobs: -DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \ -DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \ -DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} - cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - name: Python setup id: setup_python uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - 
name: Tests id: server_integration_tests @@ -102,6 +98,14 @@ jobs: export ${{ matrix.extra_args }} pytest -v -x -m "not slow" + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + export ${{ matrix.extra_args }} + SLOW_TESTS=1 pytest -v -x + server-windows: runs-on: windows-2022 @@ -124,11 +128,7 @@ jobs: uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - name: Tests id: server_integration_tests diff --git a/AGENTS.md b/AGENTS.md index 31399a7d91..117bed7f48 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and - Explicitly informing them that AI-generated pull requests are not accepted by the project - Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them -- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans +- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans - Providing useful links and pointers found throughout the codebase Examples of valid questions: diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d4ed67020..55f3d594db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,6 +109,7 @@ option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT}) +option(LLAMA_TESTS_INSTALL "llama: install tests" ON) # 3rd party libs option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c928bc39ce..7545e790f8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t 1. Explicitly disclose the manner in which AI was employed. 2. Perform a comprehensive manual review prior to submitting the pull request. 3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. -4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). For more info, please refer to the [AGENTS.md](AGENTS.md) file. 
diff --git a/README.md b/README.md index dac020ad37..5c11f38048 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [WebGPU [In Progress]](docs/build.md#webgpu) | All | | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | | [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon | +| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR | ## Obtaining and quantizing models diff --git a/SECURITY.md b/SECURITY.md index 9a93732318..3a8d07f644 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/ A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. > [!IMPORTANT] -> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 +> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 ## Requirements diff --git a/build-xcframework.sh b/build-xcframework.sh index 0eec871139..e8af16211f 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -534,7 +534,7 @@ xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-device/framework/llama.framework \ -debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \ -framework $(pwd)/build-macos/framework/llama.framework \ - -debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \ + -debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \ -framework $(pwd)/build-visionos/framework/llama.framework \ -debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \ -framework $(pwd)/build-visionos-sim/framework/llama.framework \ diff --git a/common/arg.cpp b/common/arg.cpp index 5fbc9022c0..18f953a38e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.kv_unified = value; } - ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH})); + ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, @@ -3437,16 +3437,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_size_m = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(common_arg( - {"--spec-ngram-check-rate"}, "N", - string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate), - [](common_params & params, int value) { - if (value < 1) { - throw std::invalid_argument("ngram check rate must be at least 1"); - } - params.speculative.ngram_check_rate = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--spec-ngram-min-hits"}, "N", string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits), diff --git a/common/chat.cpp b/common/chat.cpp index 2bf4632669..47a34d5822 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -380,15 +380,46 @@ std::vector 
common_chat_msgs_parse_oaicompat(const json & messa return msgs; } -json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { +static json render_message_to_json(const std::vector & msgs, const jinja::caps & c) { + if (!c.supports_string_content && !c.supports_typed_content) { + LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__); + } + + bool only_string_accepted = c.supports_string_content && !c.supports_typed_content; + bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content; + json messages = json::array(); for (const auto & msg : msgs) { - json jmsg = msg.to_json_oaicompat(concat_typed_text); - messages.push_back(jmsg); + if (only_string_accepted) { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true); + messages.push_back(jmsg); + } else if (only_typed_accepted) { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false); + if (jmsg.at("content").is_string()) { + jmsg["content"] = json::array({ + json{ + {"type", "text"}, + {"text", jmsg.at("content").get()}, + } + }); + } + messages.push_back(jmsg); + } else { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false); + messages.push_back(jmsg); + } } return messages; } +// DEPRECATED: only used in tests +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { + jinja::caps c; + c.supports_string_content = true; + c.supports_typed_content = !concat_typed_text; + return render_message_to_json(msgs, c); +} + std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -3020,7 +3051,7 @@ static common_chat_params common_chat_templates_apply_jinja( : *tmpls->template_default; const auto & src = tmpl.source(); const auto & caps = tmpl.original_caps(); - params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); params.add_generation_prompt = inputs.add_generation_prompt; params.tool_choice = inputs.tool_choice; params.reasoning_format = inputs.reasoning_format; diff --git a/common/chat.h b/common/chat.h index 24aa4aab5c..1bf43f7261 100644 --- a/common/chat.h +++ b/common/chat.h @@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * // Parses a JSON array of messages in OpenAI's chat completion API format. 
std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); + +// DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); diff --git a/common/common.cpp b/common/common.cpp index 3aa396127c..ec15804c91 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,7 +1,3 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml.h" #include "gguf.h" @@ -9,12 +5,12 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "unicode.h" #include #include #include #include -#include #include #include #include @@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { return false; } - std::u32string filename_utf32; - try { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif + size_t offset = 0; + while (offset < filename.size()) { + utf8_parse_result result = parse_utf8_codepoint(filename, offset); - std::wstring_convert, char32_t> converter; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - filename_utf32 = converter.from_bytes(filename); - - // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, - // or invalid encodings were encountered. Reject such attempts - std::string filename_reencoded = converter.to_bytes(filename_utf32); - if (filename_reencoded != filename) { + if (result.status != utf8_parse_result::SUCCESS) { return false; } - } catch (const std::exception &) { - return false; - } + uint32_t c = result.codepoint; - // Check for forbidden codepoints: - // - Control characters - // - Unicode equivalents of illegal characters - // - UTF-16 surrogate pairs - // - UTF-8 replacement character - // - Byte order mark (BOM) - // - Illegal characters: / \ : * ? " < > | - for (char32_t c : filename_utf32) { + if ((result.bytes_consumed == 2 && c < 0x80) || + (result.bytes_consumed == 3 && c < 0x800) || + (result.bytes_consumed == 4 && c < 0x10000)) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? 
" < > | if (c <= 0x1F // Control characters (C0) || c == 0x7F // Control characters (DEL) || (c >= 0x80 && c <= 0x9F) // Control characters (C1) @@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { || c == 0x2215 // Division Slash (forward slash equivalent) || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c > 0x10FFFF // Max Unicode limit || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) || c == ':' || c == '*' // Illegal characters @@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { // Subdirectories not allowed, reject path separators return false; } + offset += result.bytes_consumed; } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -1469,66 +1450,6 @@ void common_batch_add( batch.n_tokens++; } -// -// Token utils -// - -size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - // // Vocab utils // diff --git a/common/common.h b/common/common.h index 398ebb0960..804485fb19 100644 --- a/common/common.h +++ b/common/common.h @@ -269,7 +269,6 @@ struct common_params_speculative { uint16_t ngram_size_n = 12; // ngram size for lookup uint16_t ngram_size_m = 48; // mgram size for speculative tokens - uint16_t ngram_check_rate = 1; // check rate for ngram lookup uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed std::shared_ptr ngram_mod; @@ -780,16 +779,6 @@ void common_batch_add( const std::vector & seq_ids, bool logits); -// -// Token utils -// - -// longest common prefix -size_t common_lcp(const llama_tokens & a, const llama_tokens & b); - -// longet common subsequence -size_t common_lcs(const llama_tokens & a, const llama_tokens & b); - // // Vocab utils // diff --git a/common/download.cpp b/common/download.cpp index 57f29a23ba..8710438aa4 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -305,7 +305,10 @@ static bool common_pull_file(httplib::Client & cli, ); if (!res) { - LOG_ERR("%s: 
error during download. Status: %d\n", __func__, res ? res->status : -1); + LOG_ERR("%s: download failed: %s (status: %d)\n", + __func__, + httplib::to_string(res.error()).c_str(), + res ? res->status : -1); return false; } diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index f27490f1fb..dbaaed500a 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) { std::map caps::to_map() const { return { - {"requires_typed_content", requires_typed_content}, + {"supports_string_content", supports_string_content}, + {"supports_typed_content", supports_typed_content}, {"supports_tools", supports_tools}, {"supports_tool_calls", supports_tool_calls}, {"supports_parallel_tool_calls", supports_parallel_tool_calls}, @@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) { return v->stats.ops.find(op_name) != v->stats.ops.end(); }; - // case: typed content requirement + // case: typed content support caps_try_execute( prog, [&]() { @@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) { // tools return json{nullptr}; }, - [&](bool, value & messages, value &) { + [&](bool success, value & messages, value &) { auto & content = messages->at(0)->at("content"); caps_print_stats(content, "messages[0].content"); if (has_op(content, "selectattr") || has_op(content, "array_access")) { // accessed as an array - result.requires_typed_content = true; + result.supports_typed_content = true; + } + if (!success) { + // failed to execute with content as string + result.supports_string_content = false; } } ); diff --git a/common/jinja/caps.h b/common/jinja/caps.h index 77df117baa..e694e7bfaa 100644 --- a/common/jinja/caps.h +++ b/common/jinja/caps.h @@ -14,7 +14,9 @@ struct caps { bool supports_parallel_tool_calls = true; bool supports_preserve_reasoning = false; // support assistant message with reasoning_content - bool requires_typed_content = false; // default: use string content + // one of the 2 content capabilities must be true + bool supports_string_content = true; + bool supports_typed_content = false; // for reporting on server std::map to_map() const; diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 4453d86e6d..cc012c892f 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) { value iterable_val = iter_expr->execute(scope); + // mark the variable being iterated as used for stats + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("array_access"); + } + if (iterable_val->is_undefined()) { JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop"); iterable_val = mk_val(); diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index c5b8fc75ed..ebf771a24a 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -231,10 +231,9 @@ void common_ngram_map_draft(common_ngram_map & map, GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); } - // Only check every check_rate tokens to save compute - // i.e., perform check if (cur_len - idx_last_check) >= check_rate - if (map.idx_last_check + map.check_rate > cur_len) { - return; + if (map.idx_last_check > cur_len) { + // Should not happen because of common_ngram_map_begin(). 
+ GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } map.idx_last_check = cur_len; @@ -462,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map, slot_max = v; } } - // What is sum of the other occurences? + // What is sum of the other occurrences? uint32_t sum_occur = 0; for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { if (v == slot_max) { diff --git a/common/ngram-map.h b/common/ngram-map.h index 9668bd5a7c..d84e719151 100644 --- a/common/ngram-map.h +++ b/common/ngram-map.h @@ -24,7 +24,6 @@ struct common_ngram_simple_config { uint16_t size_ngram; // size of n-grams to lookup in self-mode uint16_t size_mgram; // size of m-grams to draft in self-mode - uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token }; // Searches for a n-gram in the history and checks whether a draft sequence should be generated. @@ -45,7 +44,7 @@ llama_tokens common_ngram_simple_draft( // statistics of a m-gram after a known n-gram struct common_ngram_map_value { size_t value_idx = 0; // index of value m-gram in token-history (0 if unused) - uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot) + uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot) int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused) }; @@ -54,7 +53,7 @@ struct common_ngram_map_key { size_t key_idx; // index of key n-gram in token-history size_t stat_idx; // index of last token of stastistics computation (key_num, values) - uint16_t key_num; // number of occurences of this key n-gram in token-history + uint16_t key_num; // number of occurrences of this key n-gram in token-history common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key }; @@ -66,15 +65,14 @@ struct common_ngram_map { bool key_only; // true if only key n-grams are used, no values. std::vector keys; // key n-grams which occur several times in token-history - uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token uint16_t min_hits; // minimum number of key hits to consider a draft - bool show_key_map_stats = false; // true, if statitics of the key_map should be printed. + bool show_key_map_stats = false; // true, if statistics of the key_map should be printed. common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys, - uint16_t check_rate, uint16_t min_hits) + uint16_t min_hits) : size_key(sz_key), size_value(sz_value), key_only(only_keys), - check_rate(check_rate), min_hits(min_hits) { + min_hits(min_hits) { key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used } diff --git a/common/speculative.cpp b/common/speculative.cpp index c99b19dbfd..3e68c38e49 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -113,13 +113,14 @@ static bool common_speculative_are_compatible( struct common_speculative_state { const enum common_speculative_type type; - // TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens - // TODO: add n_call_begin, n_call_accept - size_t drafts_call_count = 0; // number of times this implementation was called. - size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation. 
- size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model. - size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation. - size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model. + size_t n_call_begin = 0; // number of times this implementation was called for refresh. + size_t n_call_draft = 0; // number of times this implementation was called for generation. + size_t n_call_accept = 0; // number of times this implementation was called for accumulation. + + size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation. + size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model. + size_t n_gen_tokens = 0; // number of tokens generated by this implementation. + size_t n_acc_tokens = 0; // number of tokens accepted by the target model. // TODO: track performance of most recent calls const bool gen_perf = true; // whether to generate performance stats. @@ -465,8 +466,6 @@ struct common_speculative_state_eagle3 : public common_speculative_state { struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; - uint16_t check_id = 0; // used to control the frequency of generating drafts - common_speculative_state_ngram_simple( enum common_speculative_type type, common_ngram_simple_config config) @@ -481,11 +480,6 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { const llama_tokens & prompt_tgt, llama_token id_last, llama_tokens & result) override { - ++check_id; - if (check_id < config.check_rate) { - return; - } - check_id = 0; result = common_ngram_simple_draft(config, prompt_tgt, id_last); GGML_UNUSED(params); @@ -752,10 +746,9 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c uint16_t size_key = config.params.ngram_size_n; uint16_t size_value = config.params.ngram_size_m; bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - uint16_t check_rate = config.params.ngram_check_rate; uint16_t min_hits = config.params.ngram_min_hits; - return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits); + return common_ngram_map(size_key, size_value, key_only, min_hits); } static common_speculative_state_ngram_cache create_state_ngram_cache( @@ -805,6 +798,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } +bool common_speculative_is_compat(llama_context * ctx_tgt) { + auto * mem = llama_get_memory(ctx_tgt); + if (mem == nullptr) { + return false; + } + + bool res = true; + + llama_memory_clear(mem, true); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_memory_seq_rm(mem, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_memory_clear(mem, true); + llama_synchronize(ctx_tgt); + + return res; +} + // initialization of the speculative decoding system // common_speculative * common_speculative_init( @@ -895,12 +924,10 @@ common_speculative * common_speculative_init( uint16_t ngram_size_key = ngram_map.size_key; uint16_t 
mgram_size_value = ngram_map.size_value; - uint16_t check_rate = ngram_map.check_rate; auto config_simple = common_ngram_simple_config { /* .size_ngram = */ ngram_size_key, - /* .size_mgram = */ mgram_size_value, - /* .check_rate = */ check_rate + /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( /* .type = */ config.type, @@ -961,6 +988,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(prompt); + impl->n_call_begin++; } } @@ -977,17 +1005,17 @@ llama_tokens common_speculative_draft( { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); impl->draft(params, prompt_tgt, id_last, result); - impl->drafts_call_count++; + impl->n_call_draft++; } if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->drafts_call_count, result.size()); + impl.get()->n_call_draft, result.size()); spec->curr_impl = impl.get(); // set current implementation for stats - impl->drafts_generated_count++; - impl->drafts_generated_tokens += result.size(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); break; // We have a draft, so break out of the loop and return it. } @@ -1008,11 +1036,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { - impl->drafts_accepted_count++; - impl->drafts_accepted_tokens += n_accepted; + impl->n_acc_drafts++; + impl->n_acc_tokens += n_accepted; } impl->accept(n_accepted); + impl->n_call_accept++; } } @@ -1033,13 +1062,13 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf = ""; } - LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", + LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", common_speculative_type_to_str(impl->type).c_str(), - impl->drafts_call_count, - impl->drafts_generated_count, - impl->drafts_accepted_count, - impl->drafts_generated_tokens, - impl->drafts_accepted_tokens, + impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, + impl->n_gen_drafts, + impl->n_acc_drafts, + impl->n_gen_tokens, + impl->n_acc_tokens, str_perf.c_str()); } } diff --git a/common/speculative.h b/common/speculative.h index 76fe6bb7bc..876cde3d18 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); +// check if the llama_context is compatible for speculative decoding +// note: clears the memory of the context +bool common_speculative_is_compat(llama_context * ctx_tgt); + common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c167de8a46..2afaf85fb8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -160,8 +160,6 @@ class ModelBase: self.ftype = gguf.LlamaFileType.MOSTLY_F16 logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16") - self.dequant_model() - # Configure GGUF Writer self.gguf_writer = 
gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -527,6 +525,8 @@ class ModelBase: return () def prepare_tensors(self): + self.dequant_model() + # Handle empty tensor_map for models with block_count=0 (like MobileNetV5) if self.tensor_map.mapping: max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") @@ -920,7 +920,7 @@ class TextModel(ModelBase): self.gguf_writer.add_expert_group_used_count(n_group_used) logger.info(f"gguf: expert groups used count = {n_group_used}") - if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation_func"], optional=True)) is not None: + if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation", "moe_router_activation_func"], optional=True)) is not None: if score_func == "sigmoid": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) elif score_func == "softmax": @@ -1261,6 +1261,9 @@ class TextModel(ModelBase): if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f": # ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B res = "exaone-moe" + if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": + # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct + res = "qwen35" if res is None: logger.warning("\n") @@ -1812,7 +1815,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1867,7 +1870,15 @@ class MmprojModel(ModelBase): preprocessor_config_path = self.dir_model / "preprocessor_config.json" if preprocessor_config_path.is_file(): with open(preprocessor_config_path, "r", encoding="utf-8") as f: - self.preprocessor_config = json.load(f) + cfg = json.load(f) + # move media_proc_cfg to root level for compat + if "media_proc_cfg" in cfg: + cfg = { + **cfg, + **cfg["media_proc_cfg"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} # prefer processor_config.json if possible processor_config_path = self.dir_model / "processor_config.json" @@ -1916,10 +1927,10 @@ class MmprojModel(ModelBase): self.image_size = self.find_vparam(["image_size"]) self.gguf_writer.add_vision_image_size(self.image_size) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"])) self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) - self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"])) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"])) # preprocessor config 
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] @@ -4109,37 +4120,29 @@ class Qwen2MoeModel(TextModel): # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): mapped = f"{name}.weight" if not name.endswith(".weight") else name - # Input: (n_expert=128, n_ff_exp=768, n_embd=2048) - # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128} - # Need PyTorch: (128, 2048, 768) [reversed of GGML] - # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768) - permuted = data_torch.permute(0, 2, 1).contiguous() - yield from super().modify_tensors(permuted, mapped, bid) + # HF: [n_expert, n_embd, n_ff] -> GGML: {n_ff, n_embd, n_expert} + yield from super().modify_tensors(data_torch, mapped, bid) return if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): - if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: + if data_torch.ndim < 3 or data_torch.shape[-2] % 2 != 0: raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") - split_dim = data_torch.shape[-1] // 2 - gate = data_torch[..., :split_dim].contiguous() - up = data_torch[..., split_dim:].contiguous() - # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) - # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} - # Need PyTorch: (128, 768, 2048) [reversed of GGML] - # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) - base_name = name.removesuffix(".weight") - base = base_name.rsplit('.', 1)[0] - mapped_gate = f"{base}.gate_proj.weight" - mapped_up = f"{base}.up_proj.weight" - perm_gate = gate.permute(0, 2, 1).contiguous() - perm_up = up.permute(0, 2, 1).contiguous() - yield from super().modify_tensors(perm_gate, mapped_gate, bid) - yield from super().modify_tensors(perm_up, mapped_up, bid) + # HF: [n_expert, 2*n_ff, n_embd] -> split on dim=-2 + n_ff = data_torch.shape[-2] // 2 + gate = data_torch[..., :n_ff, :].contiguous() + up = data_torch[..., n_ff:, :].contiguous() + # gate/up: [n_expert, n_ff, n_embd] -> GGML: {n_embd, n_ff, n_expert} + base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj") + mapped_gate = f"{base_name}.gate_proj.weight" + mapped_up = f"{base_name}.up_proj.weight" + yield from super().modify_tensors(gate, mapped_gate, bid) + yield from super().modify_tensors(up, mapped_up, bid) return if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"): # skip visual tensors return + if name.find("experts") != -1: n_experts = self.hparams["num_experts"] assert bid is not None @@ -4295,6 +4298,7 @@ class Qwen3NextModel(Qwen2MoeModel): self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"]) self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"]) self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"]) + self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4)) if (rope_dim := self.hparams.get("head_dim")) is None: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) @@ -4359,7 +4363,7 @@ class 
RND1Model(Qwen2MoeModel): self.gguf_writer.add_mask_token_id(mask_token_id) -@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") +@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -4405,6 +4409,10 @@ class Qwen3VLVisionModel(MmprojModel): if name.startswith("model.language_model.") or name.startswith("lm_head."): return + # Skip MTP tensors + if name.startswith("mtp."): + return + if name.startswith("model.visual."): name = name.replace("model.visual.", "visual.", 1) @@ -4535,9 +4543,125 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel): if name.startswith("model.visual."): return + # Qwen3VL has transposed packed tensors, so we treat it differently from general Qwen2MoE packed tensors + if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): + name = name.replace("language_model.", "") + mapped = f"{name}.weight" if not name.endswith(".weight") else name + permuted = data_torch.permute(0, 2, 1).contiguous() + yield from ModelBase.modify_tensors(self, permuted, mapped, bid) + return + + if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): + name = name.replace("language_model.", "") + if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: + raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") + split_dim = data_torch.shape[-1] // 2 + gate = data_torch[..., :split_dim].contiguous() + up = data_torch[..., split_dim:].contiguous() + # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) + # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} + # Need PyTorch: (128, 768, 2048) [reversed of GGML] + # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) + base_name = name.removesuffix(".weight") + base = base_name.rsplit('.', 1)[0] + mapped_gate = f"{base}.gate_proj.weight" + mapped_up = f"{base}.up_proj.weight" + perm_gate = gate.permute(0, 2, 1).contiguous() + perm_up = up.permute(0, 2, 1).contiguous() + yield from ModelBase.modify_tensors(self, perm_gate, mapped_gate, bid) + yield from ModelBase.modify_tensors(self, perm_up, mapped_up, bid) + return + yield from super().modify_tensors(data_torch, name, bid) +class _LinearAttentionVReorderBase(Qwen3NextModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses + """reorders V heads from grouped to tiled order for ggml broadcast + + see https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + + Linear attention may have num_k_heads < num_v_heads. The HF weights store + V heads grouped by K head: [G0_v0..v{r-1}, G1_v0..v{r-1}, ...]. + ggml binary ops use tiled broadcast: [K0, K1, ..., K0, K1, ...]. + We reorder V heads to tiled order so ggml_repeat can replace the expensive + interleaved repeat: [G0_v0, G1_v0, ..., G0_v1, G1_v1, ...].
+ """ + + @staticmethod + def _reorder_v_heads(tensor: Tensor, dim: int, num_k_heads: int, num_v_per_k: int, head_dim: int) -> Tensor: + """Reorder V heads from grouped (by K head) to tiled order along the given dimension.""" + shape = list(tensor.shape) + if dim < 0: + dim += len(shape) + new_shape = shape[:dim] + [num_k_heads, num_v_per_k, head_dim] + shape[dim + 1:] + tensor = tensor.reshape(*new_shape) + perm = list(range(len(new_shape))) + perm[dim], perm[dim + 1] = perm[dim + 1], perm[dim] + return tensor.permute(*perm).contiguous().reshape(*shape) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + num_k_heads = self.hparams.get("linear_num_key_heads", 0) + num_v_heads = self.hparams.get("linear_num_value_heads", 0) + + if num_k_heads > 0 and num_v_heads > 0 and num_k_heads != num_v_heads and "linear_attn." in name: + head_k_dim = self.hparams["linear_key_head_dim"] + head_v_dim = self.hparams["linear_value_head_dim"] + num_v_per_k = num_v_heads // num_k_heads + + if ".in_proj_qkv." in name: + # QKV weight: reorder only the V rows + q_dim = head_k_dim * num_k_heads + k_dim = head_k_dim * num_k_heads + q = data_torch[:q_dim] + k = data_torch[q_dim:q_dim + k_dim] + v = data_torch[q_dim + k_dim:] + v = self._reorder_v_heads(v, 0, num_k_heads, num_v_per_k, head_v_dim) + data_torch = torch.cat([q, k, v], dim=0) + + elif ".in_proj_z." in name: + # Z gate weight: reorder rows (num_v_heads * head_v_dim) + data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, head_v_dim) + + elif ".in_proj_b." in name or ".in_proj_a." in name: + # Beta/Alpha weight: reorder rows (num_v_heads, head_dim=1) + data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, 1) + + elif ".A_log" in name or ".dt_bias" in name or ".dt_proj" in name: + # A_log / dt_bias: 1D parameters with num_v_heads elements + if data_torch.ndim == 1: + data_torch = self._reorder_v_heads( + data_torch.unsqueeze(-1), 0, num_k_heads, num_v_per_k, 1 + ).squeeze(-1) + else: + data_torch = self._reorder_v_heads(data_torch, -1, num_k_heads, num_v_per_k, 1) + + elif ".conv1d" in name: + # Conv1d kernel: reorder only the V channel portion + data = data_torch.squeeze() + qk_channels = head_k_dim * num_k_heads * 2 + qk_part = data[:qk_channels] + v_part = data[qk_channels:] + v_part = self._reorder_v_heads(v_part, 0, num_k_heads, num_v_per_k, head_v_dim) + data_torch = torch.cat([qk_part, v_part], dim=0) + + elif ".out_proj." 
in name: + # Out projection weight: reorder columns (input dimension) + data_torch = self._reorder_v_heads(data_torch, 1, num_k_heads, num_v_per_k, head_v_dim) + + yield from super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Qwen3_5ForConditionalGeneration") +class Qwen3_5TextModel(_LinearAttentionVReorderBase): + model_arch = gguf.MODEL_ARCH.QWEN35 + + +@ModelBase.register("Qwen3_5MoeForConditionalGeneration") +class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase): + model_arch = gguf.MODEL_ARCH.QWEN35MOE + + @ModelBase.register("GPT2LMHeadModel") class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 @@ -7579,6 +7703,7 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "KimiK25ForConditionalGeneration", "YoutuForCausalLM", "YoutuVLForConditionalGeneration", ) @@ -7697,8 +7822,8 @@ class DeepseekV2Model(TextModel): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # skip vision tensors and remove "language_model." for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5 + if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name: return if name.startswith("siglip2.") or name.startswith("merger."): return @@ -7912,6 +8037,135 @@ class MimoV2Model(TextModel): raise ValueError(f"Unprocessed experts: {experts}") +@ModelBase.register("Step3p5ForCausalLM") +class Step35Model(TextModel): + model_arch = gguf.MODEL_ARCH.STEP35 + + def set_gguf_parameters(self): + rope_theta = self.hparams.get("rope_theta") + if isinstance(rope_theta, list): + self.hparams["rope_theta"] = float(rope_theta[0]) + self.hparams["local_rope_theta"] = float(rope_theta[1]) + self.rope_parameters["rope_theta"] = self.hparams["rope_theta"] + self.rope_parameters["sliding_attention"] = {"rope_theta": self.hparams["local_rope_theta"]} + + super().set_gguf_parameters() + + layer_types = self.hparams.get("layer_types") or [] + partial_rotary_factors = self.hparams.get("partial_rotary_factors") or [] + attn_other = self.hparams.get("attention_other_setting") or {} + + n_head_base = self.hparams["num_attention_heads"] + n_kv_base = self.hparams["num_attention_groups"] + + n_head_swa = attn_other.get("num_attention_heads", n_head_base) + n_kv_swa = attn_other.get("num_attention_groups", n_kv_base) + + layer_types = layer_types[: self.block_count] + partial_rotary_factors = partial_rotary_factors[: self.block_count] + assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors + head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types] + kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types] + swa_pat = [lt == "sliding_attention" for lt in layer_types] + + self.gguf_writer.add_head_count(head_arr) + self.gguf_writer.add_head_count_kv(kv_arr) + + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_sliding_window_pattern(swa_pat) + + self.gguf_writer.add_value_length(self.hparams["head_dim"]) + + # MoE params + self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) + self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) + 
self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["share_expert_dim"]) + + if (moe_router_scaling_factor := self.hparams.get("moe_router_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(moe_router_scaling_factor) + if (norm_expert_weight := self.hparams.get("norm_expert_weight")) is not None: + self.gguf_writer.add_expert_weights_norm(norm_expert_weight) + + # leading dense blocks + leading_dense = 0 + moe_layers_enum = self.hparams.get("moe_layers_enum") + if isinstance(moe_layers_enum, str) and moe_layers_enum.strip(): + moe_layers = sorted(int(i) for i in moe_layers_enum.strip().split(",")) + if moe_layers: + leading_dense = max(0, moe_layers[0]) + self.gguf_writer.add_leading_dense_block_count(leading_dense) + self.gguf_writer.add_moe_every_n_layers(int(self.hparams.get("moe_every_n_layer", 1))) + + self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5)) + + # Optional per-layer SwiGLU clamps. + if (limits := self.hparams.get("swiglu_limits")) is not None: + limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]] + self.gguf_writer.add_swiglu_clamp_exp(limits_f) + if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None: + limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]] + self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + # remove mtp layers + if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None: + il = int(m.group(1)) + n_main = int(self.hparams.get("num_hidden_layers", self.block_count)) + if il >= n_main: + return + if name.endswith("norm.weight"): + data_torch += 1.0 + # Map router bias (expert selection bias) to a GGUF bias tensor + if name.endswith(".moe.router_bias"): + name += ".bias" + + if name.endswith((".self_attn.g_proj.weight", ".moe.gate.weight", ".moe.up_proj.weight", ".moe.gate_proj.weight", ".moe.down_proj.weight")): + data_torch = data_torch.squeeze().contiguous() + + yield from super().modify_tensors(data_torch, name, bid) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3"). + # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS). + rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) + rope_type = rope_params.get("rope_type") or "" + if rope_type.lower() != "llama3": + return + + # Step35 configs can carry per-layer rope_theta as a list; for llama3 rope factors we use the base value. 
+ rope_theta = self.hparams.get("rope_theta", 10000.0) + if isinstance(rope_theta, list): + rope_theta = rope_theta[0] + base = float(rope_theta) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + dim = int(dim) + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + factor = float(rope_params.get("factor", 8.0)) + low_freq_factor = float(rope_params.get("low_freq_factor", 1.0)) + high_freq_factor = float(rope_params.get("high_freq_factor", 4.0)) + old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192))) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + rope_factors: list[float] = [] + for freq in freqs: + wavelen = 2 * math.pi / float(freq) + if wavelen < high_freq_wavelen: + rope_factors.append(1.0) + elif wavelen > low_freq_wavelen: + rope_factors.append(factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth)) + + yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + + @ModelBase.register("PanguEmbeddedForCausalLM") class PanguEmbeddedModel(TextModel): model_arch = gguf.MODEL_ARCH.PANGU_EMBED @@ -10931,6 +11185,103 @@ class KimiVLModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("KimiK25ForConditionalGeneration") +class KimiK25Model(MmprojModel): + """Kimi-K2.5 with MoonViT3d vision encoder""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config" + + self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2])) + self.patch_size = self.hparams_vision.get("patch_size", 14) + + # Set image_size for compatibility with base class + # Use position embedding dimensions as image_size reference + pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64) + self.hparams_vision["image_size"] = pos_emb_h * self.patch_size + + def set_gguf_parameters(self): + # Base class MmprojModel.set_gguf_parameters() already writes: + # - vision_block_count, vision_head_count, vision_embedding_length + # - vision_feed_forward_length, vision_patch_size, image_mean, image_std + # via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config + super().set_gguf_parameters() + assert self.hparams_vision is not None + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25) + + # Position embedding parameters (for interpolation) + self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4)) + + # Projector parameters + self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu") + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5)) + self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) + + # Image size limits + # Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet) 
+ in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384) + min_patches = 8 # reasonable minimum + pixels_per_patch = self.patch_size ** 2 + self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch) + self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) + + @staticmethod + def permute(weights: Tensor, n_head: int) -> Tensor: + out_dim, in_dim = weights.shape + head_dim = out_dim // n_head + w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim) + w = w.permute(0, 2, 1, 3, 4) + return w.reshape(out_dim, in_dim) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Only process vision and projector tensors + is_vision = any(x in name for x in ["vision_tower", "mm_projector"]) + + if not is_vision: + return + + assert self.hparams_vision is not None + n_head = self.hparams_vision.get("num_attention_heads", 16) + + # Permute Q/K weights/biases from interleaved to split RoPE format + # This allows using build_rope_2d at runtime without post-permutation. + if "wqkv" in name: + out_dim = data_torch.shape[0] + qkv_dim = out_dim // 3 + head_dim = qkv_dim // n_head + + if "weight" in name: + wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :] + wq = self.permute(wq, n_head) + wk = self.permute(wk, n_head) + data_torch = torch.cat([wq, wk, wv], dim=0) + elif "bias" in name: + bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:] + bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + data_torch = torch.cat([bq, bk, bv], dim=0) + + # Temporal embeddings: (T, 1, C) → (T, C) + if "pos_emb.time_weight" in name: + T, _, C = data_torch.shape + data_torch = data_torch.reshape(T, C) + + # PatchMergerMLP tensor name mapping + # proj.0.weight → proj.linear_1.weight + # proj.2.weight → proj.linear_2.weight + if "mm_projector.proj.0." in name: + name = name.replace(".proj.0.", ".proj.linear_1.") + elif "mm_projector.proj.2." in name: + name = name.replace(".proj.2.", ".proj.linear_2.") + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("CogVLMForCausalLM") class CogVLMVisionModel(MmprojModel): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 2811f7f884..a683451508 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -148,6 +148,7 @@ models = [ {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", } ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/docs/backend/VirtGPU.md b/docs/backend/VirtGPU.md new file mode 100644 index 0000000000..c81468da13 --- /dev/null +++ b/docs/backend/VirtGPU.md @@ -0,0 +1,180 @@ +# GGML-VirtGPU Backend + +The GGML-VirtGPU backend enables GGML applications to run machine +learning computations on host hardware while the application itself +runs inside a virtual machine. It uses host-guest shared memory to +efficiently share data buffers between the two sides. 
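+
+As a minimal illustration of the guest-side workflow (assuming the
+host side is already configured as described in
+[Development and Testing](VirtGPU/development.md), and using an
+arbitrary model path):
+
+```bash
+# Guest side: build llama.cpp with the VirtGPU frontend enabled
+cmake -B build -DGGML_VIRTGPU=ON
+cmake --build build --parallel
+
+# The virtio-gpu device must be visible in the guest (e.g. /dev/dri)
+./build/bin/llama-bench -m ./model.gguf
+# the "backend" column of the report should show ggml-virtgpu
+```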
+
+This backend relies on the virtio-gpu device and the VirglRenderer API
+Remoting (APIR) component. The backend is split into two libraries:
+- a GGML implementation (the "remoting frontend"), running in the
+  guest and interacting with the virtgpu device
+- a VirglRenderer APIR-compatible library (the "remoting backend"),
+  running in the host and interacting with Virglrenderer and an actual
+  GGML device backend.
+
+## OS support
+
+| OS       | Status            | Backend     | CI testing  | Notes
+| -------- | ----------------- | ----------- | ----------- | -----
+| macOS 14 | Supported         | ggml-metal  | X           | Working when compiled on macOS 14
+| macOS 15 | Supported         | ggml-metal  | X           | Working when compiled on macOS 14 or macOS 15
+| macOS 26 | Not tested        |             |             |
+| Linux    | Under development | ggml-vulkan | not working | Working locally, CI running into deadlocks
+
+
+## Architecture Overview
+
+The GGML-VirtGPU backend consists of three main components:
+
+```mermaid
+graph TD
+    %% Nodes
+
+    subgraph GuestVM ["Guest VM - Frontend"]
+        App([GGML Application
llama.cpp, etc.]) + + direction TB + Interface[GGML Backend Interface] + Comm["GGML-VirtGPU
(hypercalls + shared mem)"] + + App --> Interface + Interface --> Comm + end + + API[virtio-gpu / virglrenderer API] + + subgraph HostSystem [Host System - Backend] + direction TB + Dispatcher[GGML-VirtGPU-Backend] + BackendLib[GGML Backend library
Metal / Vulkan / CPU / ...] + + Dispatcher --> BackendLib + end + + %% Connections + Comm --> API + API --> HostSystem +``` + +### Key Components + +1. **Guest-side Frontend** (`ggml-virtgpu/`): Implements the GGML backend interface and forwards operations to the host +2. **Host-side Backend** (`ggml-virtgpu/backend/`): Receives forwarded operations and executes them on actual hardware backends +3. **Communication Layer**: Uses virtio-gpu hypercalls and shared memory for efficient data transfer + +## Features + +- **Dynamic backend loading** on the host side (CPU, CUDA, Metal, etc.) +- **Zero-copy data transfer** via host-guest shared memory pages + +## Communication Protocol + +### Hypercalls and Shared Memory + +The backend uses two primary communication mechanisms: + +1. **Hypercalls (`DRM_IOCTL_VIRTGPU_EXECBUFFER`)**: Trigger remote execution from guest to host +2. **Shared Memory Pages**: Zero-copy data transfer for tensors and parameters + +#### Shared Memory Layout + +Each connection uses two shared memory buffers: + +- **Data Buffer** (24 MiB): For command/response data and tensor transfers +- **Reply Buffer** (16 KiB): For command replies and status information +- **Data Buffers**: Dynamically allocated host-guest shared buffers + served as GGML buffers. + +### APIR Protocol + +The Virglrender API Remoting protocol defines three command types: + +- `HANDSHAKE`: Protocol version negotiation and capability discovery +- `LOADLIBRARY`: Dynamic loading of backend libraries on the host +- `FORWARD`: API function call forwarding + +### Binary Serialization + +Commands and data are serialized using a custom binary protocol with: + +- Fixed-size encoding for basic types +- Variable-length arrays with size prefixes +- Buffer bounds checking +- Error recovery mechanisms + +## Supported Operations + +### Device Operations +- Device enumeration and capability queries +- Memory information (total/free) +- Backend type detection + +### Buffer Operations +- Buffer allocation and deallocation +- Tensor data transfer (host ↔ guest) +- Memory copying and clearing + +### Computation Operations +- Graph execution forwarding + +## Build Requirements + +### Guest-side Dependencies +- `libdrm` for DRM/virtio-gpu communication +- C++20 compatible compiler +- CMake 3.14+ + +### Host-side Dependencies +- virglrenderer with APIR support (pending upstream review) +- Target backend libraries (libggml-metal, libggml-vulkan, etc.) + +## Configuration + +### Environment Variables + +- `GGML_VIRTGPU_BACKEND_LIBRARY`: Path to the host-side backend library +- `GGML_VIRTGPU_DEBUG`: Enable debug logging + +### Build Options + +- `GGML_VIRTGPU`: Enable the VirtGPU backend (`ON` or `OFF`, default: `OFF`) +- `GGML_VIRTGPU_BACKEND`: Build the host-side backend component (`ON`, `OFF` or `ONLY`, default: `OFF`) + +### System Requirements + +- VM with virtio-gpu support +- VirglRenderer with APIR patches +- Compatible backend libraries on host + +## Limitations + +- **VM-specific**: Only works in virtual machines with virtio-gpu support +- **Host dependency**: Requires properly configured host-side backend +- **Latency**: Small overhead from VM escaping for each operation + + +* This work is pending upstream changes in the VirglRenderer + project. 
+  * The backend can be tested with Virglrenderer compiled from source
+    using this PR:
+    https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590
+* This work is pending changes in the VMM/hypervisor running the
+  virtual machine, which needs to know how to route the newly
+  introduced APIR capset.
+  * The environment variable `VIRGL_ROUTE_VENUS_TO_APIR=1` allows
+    using the Venus capset, until the relevant hypervisors have been
+    patched. However, setting this flag breaks normal Vulkan/Venus
+    behavior.
+  * The environment variable `GGML_REMOTING_USE_APIR_CAPSET` tells the
+    `ggml-virtgpu` backend to use the APIR capset. This will become
+    the default when the relevant hypervisors have been patched.
+
+* This work focused on improving the performance of llama.cpp running
+  on macOS containers, and is mainly tested on this platform. Linux
+  support (via `krun`) is in progress.
+
+## See Also
+
+- [Development and Testing](VirtGPU/development.md)
+- [Backend configuration](VirtGPU/configuration.md)
diff --git a/docs/backend/VirtGPU/configuration.md b/docs/backend/VirtGPU/configuration.md
new file mode 100644
index 0000000000..597862d5c8
--- /dev/null
+++ b/docs/backend/VirtGPU/configuration.md
@@ -0,0 +1,174 @@
+# GGML-VirtGPU Backend Configuration
+
+This document describes the environment variables used by the ggml-virtgpu backend system, covering both the frontend (guest-side) and backend (host-side) components.
+
+## Environment Variables Overview
+
+The ggml-virtgpu backend uses environment variables for configuration across three main components:
+- **Frontend (Guest)**: GGML applications running in VMs
+- **Hypervisor**: Virglrenderer/APIR system
+- **Backend (Host)**: Host-side GGML backend integration
+
+## Frontend (Guest-side) Configuration
+
+### GGML_REMOTING_USE_APIR_CAPSET
+- **Location**: `ggml/src/ggml-virtgpu/virtgpu.cpp`
+- **Type**: Boolean flag (presence-based)
+- **Purpose**: Controls which virtio-gpu capability set to use for communication
+- **Values**:
+  - Set (any value): Use the APIR capset (long-term setup)
+  - Unset: Use the Venus capset (easier for testing with an unmodified hypervisor)
+- **Default**: Unset (Venus capset)
+- **Usage**:
+  ```bash
+  export GGML_REMOTING_USE_APIR_CAPSET=1 # Use APIR capset
+  # or leave unset for Venus capset
+  ```
+
+## Hypervisor (Virglrenderer/APIR) Configuration
+
+These environment variables are used during the transition phase for
+running with an unmodified hypervisor (not supporting the
+VirglRenderer APIR component). They will be removed in the future, and
+the hypervisor will instead configure VirglRenderer with the APIR
+_Configuration Key_.
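+
+As a quick reference, this is how the transitional environment
+variables relate to the APIR configuration keys documented in the
+sections below (paths are illustrative):
+
+```bash
+# Read by virglrenderer/APIR until the hypervisor sets the
+# corresponding configuration keys itself:
+#   VIRGL_APIR_BACKEND_LIBRARY       -> apir.load_library.path
+#   APIR_LLAMA_CPP_GGML_LIBRARY_PATH -> ggml.library.path
+#   APIR_LLAMA_CPP_GGML_LIBRARY_REG  -> ggml.library.reg
+export VIRGL_APIR_BACKEND_LIBRARY="/path/to/libggml-virtgpu-backend.dylib"
+export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/path/to/libggml-metal.dylib"
+export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg"
+```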
+ +### VIRGL_APIR_BACKEND_LIBRARY +- **Location**: `virglrenderer/src/apir/apir-context.c` +- **Configuration Key**: `apir.load_library.path` +- **Type**: File path string +- **Purpose**: Path to the APIR backend library that virglrenderer should dynamically load +- **Required**: Yes +- **Example**: + ```bash + export VIRGL_APIR_BACKEND_LIBRARY="/path/to/libggml-remotingbackend.so" + ``` + +### VIRGL_ROUTE_VENUS_TO_APIR +- **Location**: `virglrenderer/src/apir/apir-renderer.h` +- **Type**: Boolean flag (presence-based) +- **Purpose**: Temporary workaround to route Venus capset calls to APIR during hypervisor transition period +- **Status**: will be removed once hypervisors support APIR natively +- **Warning**: Breaks normal Vulkan/Venus functionality +- **Usage**: + ```bash + export VIRGL_ROUTE_VENUS_TO_APIR=1 # For testing with an unmodified hypervisor + ``` + +### VIRGL_APIR_LOG_TO_FILE +- **Location**: `virglrenderer/src/apir/apir-renderer.c` +- **Environment Variable**: `VIRGL_APIR_LOG_TO_FILE` +- **Type**: File path string +- **Purpose**: Enable debug logging from the VirglRenderer APIR component to specified file +- **Required**: No (optional debugging) +- **Default**: Logging to `stderr` +- **Usage**: + ```bash + export VIRGL_APIR_LOG_TO_FILE="/tmp/apir-debug.log" + ``` + +## Backend (Host-side) Configuration + +These environment variables are used during the transition phase for +running with an unmodified hypervisor (not supporting the +VirglRenderer APIR component). They will be removed in the future, and +the hypervisor will instead configure VirglRenderer with the APIR +_Configuration Key_. + +### APIR_LLAMA_CPP_GGML_LIBRARY_PATH +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp` +- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_PATH` +- **Configuration Key**: `ggml.library.path` +- **Type**: File path string +- **Purpose**: Path to the actual GGML backend library (Metal, CUDA, Vulkan, etc.) 
+- **Required**: **Yes** - backend initialization fails without this +- **Examples**: + ```bash + # macOS with Metal backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib" + + # Linux with CUDA backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-cuda.so" + + # macOS or Linux with Vulkan backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-vulkan.so" + ``` + +### APIR_LLAMA_CPP_GGML_LIBRARY_REG +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp` +- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_REG` +- **Configuration Key**: `ggml.library.reg` +- **Type**: Function symbol name string +- **Purpose**: Name of the backend registration function to call after loading the library +- **Required**: No (defaults to `ggml_backend_init`) +- **Default**: `ggml_backend_init` +- **Examples**: + ```bash + # Metal backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg" + + # CUDA backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_cuda_reg" + + # Vulkan backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_vulkan_reg" + + # Generic fallback (default) + # export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_init" + ``` + +### APIR_LLAMA_CPP_LOG_TO_FILE +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp:62` +- **Environment Variable**: `APIR_LLAMA_CPP_LOG_TO_FILE` +- **Type**: File path string +- **Purpose**: Enable debug logging from the GGML backend to specified file +- **Required**: No (optional debugging) +- **Usage**: + ```bash + export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml-backend-debug.log" + ``` + +## Configuration Flow + +The configuration system works as follows: + +1. **Hypervisor Setup**: Virglrenderer loads the APIR backend library specified by `VIRGL_APIR_BACKEND_LIBRARY` + +2. **Context Creation**: When an APIR context is created, it populates a configuration table with environment variables: + - `apir.load_library.path` ← `VIRGL_APIR_BACKEND_LIBRARY` + - `ggml.library.path` ← `APIR_LLAMA_CPP_GGML_LIBRARY_PATH` + - `ggml.library.reg` ← `APIR_LLAMA_CPP_GGML_LIBRARY_REG` + - this step will eventually be performed by the hypervisor itself, with command-line arguments instead of environment variables. + +3. **Backend Initialization**: The backend queries the configuration via callbacks: + - `virgl_cbs->get_config(ctx_id, "ggml.library.path")` returns the library path + - `virgl_cbs->get_config(ctx_id, "ggml.library.reg")` returns the registration function + +4. 
**Library Loading**: The backend dynamically loads and initializes the specified GGML library + +## Error Messages + +Common error scenarios and their messages: + +- **Missing library path**: `"cannot open the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_PATH' not defined"` +- **Missing registration function**: `"cannot register the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_REG' not defined"` + +## Example Complete Configuration + +Here's an example configuration for a macOS host with Metal backend: + +```bash +# Hypervisor environment +export VIRGL_APIR_BACKEND_LIBRARY="/opt/llama.cpp/lib/libggml-virtgpu-backend.dylib" + +# Backend configuration +export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib" +export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg" + +# Optional logging +export VIRGL_APIR_LOG_TO_FILE="/tmp/apir.log" +export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml.log" + +# Guest configuration +export GGML_REMOTING_USE_APIR_CAPSET=1 +``` diff --git a/docs/backend/VirtGPU/development.md b/docs/backend/VirtGPU/development.md new file mode 100644 index 0000000000..ca2e47772a --- /dev/null +++ b/docs/backend/VirtGPU/development.md @@ -0,0 +1,220 @@ +# Development and Testing + +## Development + +### Code Generation + +The backend uses code generation from YAML configuration: + +```bash +# Regenerate protocol code +cd ggml-virtgpu/ +python regenerate_remoting.py +``` + +### Adding New Operations + +1. Add function definition to `ggmlremoting_functions.yaml` +2. Regenerate code with `regenerate_remoting.py` +3. Implement guest-side forwarding in `virtgpu-forward-*.cpp` +4. Implement host-side handling in `backend-dispatched-*.cpp` + +## Testing + +This document provides instructions for building and testing the GGML-VirtGPU backend on macOS with containers. + +### Prerequisites + +The testing setup requires: + +- macOS host system +- Container runtime with `libkrun` provider (podman machine) +- Access to development patchset for VirglRenderer + +### Required Patchsets + +The backend requires patches that are currently under review: + +- **Virglrenderer APIR upstream PR**: https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590 (for reference) +- **MacOS Virglrenderer (for krunkit)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-macos +- **Linux Virglrenderer (for krun)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-linux + +### Build Instructions + +#### 1. Build ggml-virtgpu-backend (Host-side, macOS) + +```bash +# Build the backend that runs natively on macOS +mkdir llama.cpp +cd llama.cpp +git clone https://github.com/ggml-org/llama.cpp.git src +cd src + +LLAMA_MAC_BUILD=$PWD/build/ggml-virtgpu-backend + +cmake -S . -B $LLAMA_MAC_BUILD \ + -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=ON \ + -DGGML_REMOTINGBACKEND=ONLY \ + -DGGML_METAL=ON + +TARGETS="ggml-metal" +cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $TARGETS + +# Build additional tools for native benchmarking +EXTRA_TARGETS="llama-run llama-bench" +cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $EXTRA_TARGETS +``` + +#### 2. 
Build virglrenderer (Host-side, macOS) + +```bash +# Build virglrenderer with APIR support +mkdir virglrenderer +git clone https://gitlab.freedesktop.org/kpouget/virglrenderer -b main-macos src +cd src + +VIRGL_BUILD_DIR=$PWD/build + +# -Dvenus=true and VIRGL_ROUTE_VENUS_TO_APIR=1 route the APIR requests via the Venus backend, for easier testing without a patched hypervisor + +meson setup $VIRGL_BUILD_DIR \ + -Dvenus=true \ + -Dapir=true + +ninja -C $VIRGL_BUILD_DIR +``` + +#### 3. Build ggml-virtgpu (Guest-side, Linux) + +Option A: Build from a script: + +```bash +# Inside a Linux container +mkdir llama.cpp +git clone https://github.com/ggml-org/llama.cpp.git src +cd src + +LLAMA_LINUX_BUILD=$PWD//build-virtgpu + +cmake -S . -B $LLAMA_LINUX_BUILD \ + -DGGML_VIRTGPU=ON + +ninja -C $LLAMA_LINUX_BUILD +``` + +Option B: Build container image with frontend: + +```bash +cat << EOF > remoting.containerfile +FROM quay.io/fedora/fedora:43 +USER 0 + +WORKDIR /app/remoting + +ARG LLAMA_CPP_REPO="https://github.com/ggml-org/llama.cpp.git" +ARG LLAMA_CPP_VERSION="master" +ARG LLAMA_CPP_CMAKE_FLAGS="-DGGML_VIRTGPU=ON" +ARG LLAMA_CPP_CMAKE_BUILD_FLAGS="--parallel 4" + +RUN dnf install -y git cmake gcc gcc-c++ libcurl-devel libdrm-devel + +RUN git clone "\${LLAMA_CPP_REPO}" src \\ + && git -C src fetch origin \${LLAMA_CPP_VERSION} \\ + && git -C src reset --hard FETCH_HEAD + +RUN mkdir -p build \\ + && cd src \\ + && set -o pipefail \\ + && cmake -S . -B ../build \${LLAMA_CPP_CMAKE_FLAGS} \\ + && cmake --build ../build/ \${LLAMA_CPP_CMAKE_BUILD_FLAGS} + +ENTRYPOINT ["/app/remoting/src/build/bin/llama-server"] +EOF + +mkdir -p empty_dir +podman build -f remoting.containerfile ./empty_dir -t localhost/llama-cpp.virtgpu +``` + +### Environment Setup + +#### Set krunkit Environment Variables + +```bash +# Define the base directories (adapt these paths to your system) +VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build +LLAMA_MAC_BUILD=$HOME/remoting/llama.cpp/build-backend + +# For krunkit to load the custom virglrenderer library +export DYLD_LIBRARY_PATH=$VIRGL_BUILD_DIR/src + +# For Virglrenderer to load the ggml-remotingbackend library +export VIRGL_APIR_BACKEND_LIBRARY="$LLAMA_MAC_BUILD/bin/libggml-virtgpu-backend.dylib" + +# For llama.cpp remotingbackend to load the ggml-metal backend +export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="$LLAMA_MAC_BUILD/bin/libggml-metal.dylib" +export APIR_LLAMA_CPP_GGML_LIBRARY_REG=ggml_backend_metal_reg +``` + +#### Launch Container Environment + +```bash +# Set container provider to libkrun +export CONTAINERS_MACHINE_PROVIDER=libkrun +podman machine start +``` + +#### Verify Environment + +Confirm that krunkit is using the correct virglrenderer library: + +```bash +lsof -c krunkit | grep virglrenderer +# Expected output: +# krunkit 50574 user txt REG 1,14 2273912 10849442 ($VIRGL_BUILD_DIR/src)/libvirglrenderer.1.dylib +``` + +### Running Tests + +#### Launch Test Container + +```bash +# Optional model caching +mkdir -p models +PODMAN_CACHE_ARGS="-v models:/models --user root:root --cgroupns host --security-opt label=disable -w /models" + +podman run $PODMAN_CACHE_ARGS -it --rm --device /dev/dri localhost/llama-cpp.virtgpu +``` + +#### Test llama.cpp in Container + +```bash + +# Run performance benchmark +/app/remoting/build/bin/llama-bench -m ./llama3.2 +``` + +Expected output (performance may vary): +``` +| model | size | params | backend | ngl | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: 
| +| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | pp512 | 991.30 ± 0.66 | +| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | tg128 | 85.71 ± 0.11 | +``` + +### Troubleshooting + +#### SSH Environment Variable Issues + +⚠️ **Warning**: Setting `DYLD_LIBRARY_PATH` from SSH doesn't work on macOS. Here is a workaround: + +**Workaround 1: Replace system library** +```bash +VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build # ⚠️ adapt to your system +BREW_VIRGL_DIR=/opt/homebrew/Cellar/virglrenderer/0.10.4d/lib +VIRGL_LIB=libvirglrenderer.1.dylib + +cd $BREW_VIRGL_DIR +mv $VIRGL_LIB ${VIRGL_LIB}.orig +ln -s $VIRGL_BUILD_DIR/src/$VIRGL_LIB +``` diff --git a/docs/backend/snapdragon/README.md b/docs/backend/snapdragon/README.md index 8e1f37b206..2c3f88e91a 100644 --- a/docs/backend/snapdragon/README.md +++ b/docs/backend/snapdragon/README.md @@ -35,7 +35,7 @@ Adapt below build commands accordingly. Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: ``` -[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json . +[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json . [d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon Preset CMake variables: diff --git a/docs/ops.md b/docs/ops.md index ef1ebff8b0..5754b0a96c 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,7 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index 2aa51304b3..c1622cc6f0 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -77,8 +77,8 @@ "SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" @@ -161,8 +161,8 @@ "SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" -"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL" +"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL" "SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" "SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL" "SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL" diff --git a/docs/speculative.md b/docs/speculative.md 
index 03afab5b41..29da332875 100644 --- a/docs/speculative.md +++ b/docs/speculative.md @@ -119,8 +119,6 @@ If a draft model is combined with a draftless decoding the draftless decoding ha of lookup n-gram (default: 12) --spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: 48) ---spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding - (default: 1) --spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1) ``` @@ -153,10 +151,6 @@ Sets the size M of the draft m-gram for n-gram map based speculative decoding. The m-gram size determines how many tokens to draft when a match is found. Larger values can provide more speedup but may reduce acceptance rate. -### `--spec-ngram-check-rate R` - -This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token). - ### `--spec-ngram-min-hits H` This option defines how often a key has to appear in the token history to be used as a draft (default is 1). @@ -175,7 +169,12 @@ draft acceptance rate = 0.70312 ( 90 accepted / 128 generated) statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms ``` -- `#calls`: number of calls of this implementations +``` +statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms +``` + + +- `#calls(b,g,a)`: number of calls of begin (new prompt), generation and accumulation of this implementations - `#gen drafts`: number of drafts generated by this implementation - `#acc drafts`: number of drafts accepted (partially) by the main model - `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 8a693f84af..311fa5fe36 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, int best_score = 0; fs::path best_path; + std::error_code ec; for (const auto & search_path : search_paths) { - if (std::error_code ec; !fs::exists(search_path, ec)) { + if (!fs::exists(search_path, ec)) { if (ec) { GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str()); } else { @@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - if (entry.is_regular_file()) { + if (entry.is_regular_file(ec)) { auto filename = entry.path().filename(); auto ext = entry.path().extension(); if (filename.native().find(file_prefix) == 0 && ext == file_extension) { diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 87ac05748e..fc7c3e3b72 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor } /** - * @brief Performs expert-specific matrix multiplication (MoE) with - * quantized precision using the CANN backend. + * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE) + * models using the CANN backend. 
* - * This function executes a matrix multiplication operation tailored for - * Mixture of Experts (MoE) models, where the input tensor is multiplied - * with expert-specific quantized weight matrices. It leverages the CANN - * backend to perform efficient low-precision computations and stores the - * quantized result in the destination tensor `dst`. + * This function implements MUL_MAT_ID operation for quantized weight matrices + * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on + * the provided expert indices, and computes matrix multiplication using CANN's + * WeightQuantBatchMatmulV2 operator. * - * Quantization techniques reduce memory footprint and improve performance - * by using lower-bit representations (e.g., int8) instead of floating-point. - * This function is designed to work with such formats and may incorporate - * optimizations like identity-based fast paths or routing masks for sparse - * expert selection. + * The function performs the following steps: + * 1. Converts input/output tensors to F16 format if necessary + * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices + * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2 + * 4. Converts output back to the target type if needed * - * @param ctx The context for executing CANN backend operations. - * @param dst The destination tensor where the quantized MoE multiplication result - * will be stored. + * Tensor shapes: + * - dst: [M, K, N, 1] - output tensor + * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0) + * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast) + * - ids: [K, N] - expert indices for routing * - * @note This function assumes quantized data types and is designed for - * MoE architectures with potential sparse expert routing. + * @param ctx The CANN backend context for operation execution. + * @param dst The destination tensor where the multiplication result will be stored. + * + * @note Only Q4_0 and Q8_0 quantization formats are supported. + * @note The function handles automatic type conversion to/from F16 as needed by the hardware. 
*/ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - // TODO: Use aclnnGroupedMatMul - //dst [M, K, N, 1] - ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] - ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 - ggml_tensor * ids = dst->src[2]; //ids [K, N] + // dst: [M, K, N, 1] + // src0: [D, M, A, 1] - quantized weights + // src1: [D, B, N, 1] - input activations, B = K or B = 1 + // ids: [K, N] - expert indices + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(src1->ne[2] == ids->ne[1]); - // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K + const int64_t n_batches = ids->ne[1]; + const int64_t n_select_experts = ids->ne[0]; + const enum ggml_type type = src0->type; - std::vector ids_host(ggml_nbytes(ids)); - ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream())); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32 + GGML_ASSERT(group_size == QK4_0); - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + // Calculate element size for quantized weights + const float weight_elem_size = + (type == GGML_TYPE_Q4_0) ? 0.5f : + (type == GGML_TYPE_Q8_0) ? 1.0f : + (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f); - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + // Calculate scale offset in memory + const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size; + const size_t scale_elem_size = sizeof(uint16_t); + char * scale_data = (char *) src0->data + weight_size; - const enum ggml_type type = dst->src[0]->type; - float weight_elem_size; - if (type == GGML_TYPE_Q4_0) { - weight_elem_size = float(sizeof(uint8_t)) / 2; - } else if (type == GGML_TYPE_Q8_0) { - weight_elem_size = float(sizeof(uint8_t)); - } else { - GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); - } + // Allocate buffers for selected expert weights and scales + const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size; + ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size); + void * selected_weight_buffer = selected_weight_alloc.get(); - // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = weight_elem_size; - src0_row.nb[1] = weight_elem_size * ne00; - src0_row.nb[2] = weight_elem_size * ne00; - src0_row.nb[3] = weight_elem_size * ne00; - size_t weight_stride = ne00 * ne01 * weight_elem_size; - size_t weight_size = weight_stride * ne02 * ne03; + const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size; + ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size); + void * selected_scale_buffer = selected_scale_alloc.get(); - // scale [D, M, 1, 1] -> scale && permute - size_t scale_elem_size = sizeof(uint16_t); - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + // Helper lambda to allocate and cast tensor to F16 if needed + constexpr size_t f16_elem_size = 
sizeof(uint16_t); + auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator, + bool need_cast = false) -> void * { + if (tensor->type == GGML_TYPE_F16) { + return tensor->data; + } - // src1_row [D, 1, 1, 1] -> input - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; + size_t total_size = f16_elem_size; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + total_size *= tensor->ne[i]; + } + void * buffer = allocator.alloc(total_size); - // dst_row [M, 1, 1, 1] -> out - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; + if (need_cast == false) { + return buffer; + } - //create weight for one row - ggml_cann_pool_alloc weight_allocator(ctx.pool()); - void * weight_buffer = weight_allocator.alloc(nb02); - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = tensor->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; + acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor); + acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16); - int64_t i1 = id; - int64_t i2 = i12; + return buffer; + }; - void * src0_tmp_ptr = src0_original + i02 * weight_stride; - void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride; - void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12; - void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2; + // Prepare input and output buffers + ggml_cann_pool_alloc input_alloc(ctx.pool()); + void * input_buffer = prepare_f16_buffer(src1, input_alloc, true); - // mem cpy - ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - void * scale_buffer = (char *) weight_buffer + weight_stride; - ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_pool_alloc output_alloc(ctx.pool()); + void * output_buffer = prepare_f16_buffer(dst, output_alloc, false); - src0_row.data = weight_buffer; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; - dst_row.src[0] = &src0_row; - dst_row.src[1] = &src1_row; + // Process each batch + for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { + // Create index tensor for current batch + const size_t index_offset = batch_idx * ids->nb[1]; + acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset); - ggml_cann_mul_mat(ctx, &dst_row); + // Select quantized weights using expert indices + // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte + const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? 
src0->ne[0] / 2 : src0->ne[0]; + const int64_t weight_m = src0->ne[1]; + const int64_t weight_n_experts = src0->ne[2]; + + int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts }; + size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr all_weights = + ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3); + + int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts }; + size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), + weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t), + selected_weight_ne, selected_weight_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get()); + + // Select scales using the same expert indices + const int64_t scale_d = src0->ne[0] / group_size; + int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts }; + size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr all_scales = + ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3); + + int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts }; + size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, + scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, + selected_scale_ne, selected_scale_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get()); + + // Process each expert for current batch + // IndexSelect output layout: [D, M, K] in contiguous format + // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride + for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) { + // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input + const size_t input_offset = + (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 
0 : expert_idx)) * src1->ne[0] * f16_elem_size; + const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size; + + // Create weight view for current expert: [D, M, K] -> [M, D] + int64_t weight_view_ne[2] = { weight_m, src0->ne[0] }; + float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size }; + const size_t weight_view_offset = expert_idx * selected_weight_nb[2]; + + acl_tensor_ptr weight_view = + ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size, + weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset); + + // Create scale view for current expert: [D, M, K] -> [M, D] + int64_t scale_view_ne[2] = { weight_m, scale_d }; + size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] }; + const size_t scale_view_offset = expert_idx * selected_scale_nb[2]; + + acl_tensor_ptr scale_view = + ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne, + scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset); + + // Create input activation tensor [D, 1] + int64_t input_ne[2] = { src1->ne[0], 1 }; + size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size }; + + acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne, + input_nb, 2, ACL_FORMAT_ND, input_offset); + + // Create output tensor [M, 1] + int64_t output_ne[2] = { dst->ne[0], 1 }; + size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size }; + + acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne, + output_nb, 2, ACL_FORMAT_ND, output_offset); + + // Perform quantized matrix multiplication + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(), + scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size, + output_tensor.get()); } } - return; + + // Cast output back to original type if we used a temporary F16 buffer + if (dst->type != GGML_TYPE_F16) { + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = dst->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } + + acl_tensor_ptr f16_output = + ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst); + + aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } } void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 6b2dbdd359..3f3de9f0bc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context { ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } }; +// cann buffer type /** - * @brief Check if a buffer is a CANN buffer. - * - * This function checks if a given buffer is a CANN buffer by comparing its - * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`. - * - * @param buffer The buffer to check. - * @return true if the buffer is a CANN buffer, false otherwise. + * @brief Structure representing context information for a specific backend + * buffer type. 
*/ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); +struct ggml_backend_cann_buffer_type_context { + int32_t device; /**< Device identifier associated with the buffer context. */ + std::string name; /**< Name associated with the buffer context. */ +}; -static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { - return ggml_backend_buft_is_cann(buffer->buft); +/** + * @brief Retrieves the name associated with a CANN buffer type. + * + * This function returns the descriptive name associated with the specified + * CANN buffer type context. + * + * @param buft Pointer to the buffer type context. + * @return Const pointer to the C-style string containing the name. + */ +static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; + + return buft_ctx->name.c_str(); +} + +/** + * @brief Checks if the backend buffer type is associated with the CANN backend. + * + * This function checks whether the provided backend buffer type is associated + * with the CANN backend based on the comparison of its name retrieval function + * pointer. + * + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the buffer type is associated with the CANN + * backend, otherwise false. + */ +static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } /** @@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer, static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - if (ggml_backend_buffer_is_cann(src->buffer)) { + if (ggml_backend_buft_is_cann(src->buffer->buft)) { ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context; ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context; @@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .reset = */ NULL, }; -// cann buffer type -/** - * @brief Structure representing context information for a specific backend - * buffer type. - */ -struct ggml_backend_cann_buffer_type_context { - int32_t device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ -}; - -/** - * @brief Retrieves the name associated with a CANN buffer type. - * - * This function returns the descriptive name associated with the specified - * CANN buffer type context. - * - * @param buft Pointer to the buffer type context. - * @return Const pointer to the C-style string containing the name. - */ -static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; - - return buft_ctx->name.c_str(); -} - /** * @brief Allocates a new CANN buffer of the specified type and size. 
* @@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src, GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src)); - if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { + if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) { return false; } @@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten GGML_UNUSED(dev); } -/** - * @brief Checks if the backend buffer type is associated with the CANN backend. - * - * This function checks whether the provided backend buffer type is associated - * with the CANN backend based on the comparison of its name retrieval function - * pointer. - * - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the buffer type is associated with the CANN - * backend, otherwise false. - */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; -} - /** * @brief Records an event on the CANN backend stream. * diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 427c1146e4..c6eb75b230 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -43,6 +43,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -55,7 +56,8 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K -# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -76,6 +78,7 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -84,6 +87,7 @@ #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -107,6 +111,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define 
ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -119,6 +124,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -143,6 +149,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -155,6 +162,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -186,6 +194,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -197,6 +206,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -227,6 +237,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -239,6 +250,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define 
ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -271,6 +283,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -283,6 +296,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 99bb70274c..fd05c609f7 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x4_t acc[col_groups]; + for (int i = 0; i < col_groups; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + 
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers four interleaved columns of q6) + int8x16_t q8_l[4]; + int8x16_t q8_h[4]; + for (int i = 0; i < 4; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4)); + q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4)); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3], + q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] }; + const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3], + q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] }; + + // Process column groups (0-3, 4-7) + for (int g = 0; g < col_groups; g++) { + int32x4_t sb_acc_l = vdupq_n_s32(0); + int32x4_t sb_acc_h = vdupq_n_s32(0); + + for (int chunk = 0; chunk < 4; chunk++) { + const int idx = chunk * 2 + g; + + const uint8x16_t q6_qs_l = q6_ql[idx]; + const uint8x16_t q6_qs_h = q6_qh[idx]; + + // Extract high 2 
bits for upper nibble reconstruction + const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4)); + const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh)); + + sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]); + } + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4)); + const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4)); + + acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l); + acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsubq_s32(acc[0], bias_lo); + acc[1] = vsubq_s32(acc[1], bias_hi); + + // Apply superblock scale (no mins for q6_K) + // acc[g] has [c0, c1, c2, c3] + float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0); + float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1); + + acc_f32[0] = vaddq_f32(acc_f32[0], w_0123); + acc_f32[1] = vaddq_f32(acc_f32[1], w_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n, q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); } - // TODO: Test other qh repack patterns to reduce loads const int ql_off_base = sb * QK_K / 2; const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) - ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); - ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); - ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); - ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); // Adjust qh for subblocks 2 and 3 (shift right by 2) if (sb > 1) { @@ -3474,6 +3662,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int col_groups = ncols_interleaved / 4; + constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = 
vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); + float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0); + sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1); + sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2); + sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3); + + int32x4_t acc_s32[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_s32[i] = vdupq_n_s32(0); + } + + int16_t q6_scales[8 * 16]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + for (int sb = 0; sb < QK_K / 64; sb++) { + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + // 4 rows * 16 elements per scale + // 4 reads of 16 bytes each + constexpr int reads_per_sb = 4; + int8x16_t q8_l[reads_per_sb]; + int8x16_t q8_h[reads_per_sb]; + for (int k = 0; k < reads_per_sb; k++) { + q8_l[k] = vld1q_s8(q8_base_l + 16 * k); + q8_h[k] = vld1q_s8(q8_base_h + 16 * k); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; + + uint8x16_t q6_ql_0123[reads_per_sb]; + uint8x16_t q6_ql_4567[reads_per_sb]; + uint8x16_t q6_qh_0123[reads_per_sb]; + uint8x16_t q6_qh_4567[reads_per_sb]; + + for (int k = 0; k < reads_per_sb; k++) { + q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32); + q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16); + q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32); + q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16); + } + + if (sb > 1) { + for (int k = 0; k < reads_per_sb; k++) { + q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2); + q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2); + } + } + + for (int k = 0; k < reads_per_sb; k++) { + // q = (ql | qh) - 32 + const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo); + const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi); + const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo); + const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi); + + const int8x16_t q6_0123_lo = vsubq_s8( + 
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s); + const int8x16_t q6_0123_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123 + + const int8x16_t q6_4567_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s); + const int8x16_t q6_4567_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567 + } + + // Scale and bias + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + for (int g = 0; g < col_groups; g++) { + const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4); + const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4); + const int32x4_t scale_vec_l = vmovl_s16(scales_l16); + const int32x4_t scale_vec_h = vmovl_s16(scales_h16); + const int acc_offset = g * q8_k_blocklen; + + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx = row * 2 + g; + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l); + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h); + } + } + } + } + + // Finally we apply the superblock scales + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx0 = 2 * row; + const int idx1 = 2 * row + 1; + const int32x4_t acc_0123 = acc_s32[idx0]; + const int32x4_t acc_4567 = acc_s32[idx1]; + + acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]); + acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/binary-ops.cpp 
b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43ae0..75e3829001 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ce15b18ce0..4352e13280 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int 
k = 0; k < nc; k++) { @@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16( 
const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -7629,8 +7693,7 @@ static void ggml_compute_forward_pad_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT( 
dst->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 24e8ab4618..4cb7cdeb07 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +template +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +template +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half 
= 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, } +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - constexpr int qk = QK_K; - const int nb = n / qk; - 
const int ncols_interleaved = 8; - const int blocklen = 8; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - } - - for (int l = 0; l < nb; l++) { - - - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - // qh indexing with 8-byte interleaving (like q5_K) - const int qh_byte_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_byte_l / 8; - const int qh_pos_l = qh_byte_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_byte_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_byte_h / 8; - const int qh_pos_h = qh_byte_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t a_l = a_ptr[l].qs[base_l + i]; - const int8_t a_h = a_ptr[l].qs[base_h + i]; - - sumi_l += q_l * a_l; - sumi_h += q_h * a_h; - } - - sumf[j] += - (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - } - - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } + ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, } } -void ggml_gemm_q6_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - - float sumf[4][8]; - - for (int y = 0; y < nr / 4; y++) { - const 
block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0f; - } - } - - for (int l = 0; l < nb; l++) { - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - // Activation base indices for q8_Kx4 interleaved format - // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 - const int q8_base = (k / 8) * 512 + (k % 8) * 32; - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / 8; - const int qh_pos_l = qh_idx_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / 8; - const int qh_pos_h = qh_idx_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; - const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; - - sumi_l += q_l * q8_l; - sumi_h += q_h * q8_h; - } - - sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * - a_ptr[l].d[m]; - } - } - } - } - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2097,18 +2112,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in } const int end_ls = QK_K * 4 / blck_size_interleave; - // Interleave Q6_K quants by taking 8 bytes at a time + // Interleave Q6_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end_ls; ++i) { int src_id = i % n_blocks; int src_offset = (i / n_blocks) * blck_size_interleave; int dst_offset = i * blck_size_interleave; uint64_t 
elem_ls; - memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); - memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); + memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); } - // Interleave high bits using same 8-byte pattern as low bits + // Interleave high bits using same chunk size as low bits const int end_hs = end_ls / 2; for (int i = 0; i < end_hs; ++i) { int src_id = i % n_blocks; @@ -2116,8 +2131,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in int dst_offset = i * blck_size_interleave; uint64_t elem_hs; - memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); - memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); + memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); } // The below logic is designed so as to unpack and rearrange scales in Q6_K @@ -2262,7 +2277,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q6_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; @@ -2511,6 +2526,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); } @@ -2575,6 +2594,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2634,6 +2657,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3043,6 +3070,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; // instance for Q2 @@ -3107,6 +3135,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q6_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 855320eeeb..39b6b48238 100644 --- 
a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f..1d8344436f 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d313c1ac9a..262f88204e 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND) FetchContent_Declare( CCCL GIT_REPOSITORY https://github.com/nvidia/cccl.git - GIT_TAG v3.2.0-rc2 + GIT_TAG v3.2.0 GIT_SHALLOW TRUE ) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b1e..7339fe0c07 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? 
(float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + //size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00 
,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } } diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ba3d4eeb88..09b6d5db6a 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -7,7 +7,8 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, - const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); @@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 
1 : qk/2; - const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } template @@ -485,9 +490,11 @@ template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template @@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t template static __global__ void convert_unary( - const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -621,23 +629,29 @@ static __global__ void convert_unary( } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; const src_t * x = (const src_t *) vx; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } template static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8694fd06c7..35735d48b2 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 
8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; +#else typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; typedef wmma::fragment frag_c_KQ; typedef wmma::fragment frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) + const _Float16 * K_h_f16 = reinterpret_cast(K_h); + const _Float16 * V_h_f16 = reinterpret_cast(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9e77c231c8..7dc688483a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3640,11 +3640,13 @@ static void 
ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { + ggml_tensor fused_add_node; + memcpy(&fused_add_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; + ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; @@ -4834,8 +4836,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: - case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); + case GGML_OP_PAD: + return true; case GGML_OP_UPSCALE: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 660c192e48..31cd00f778 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; } -static __global__ void pad_f32(const float * src, float * dst, +static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i02 = wrap_around(i2 - lp2, ne02); const int64_t i03 = wrap_around(i3 - lp3, ne03); - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } } -static void pad_f32_cuda(const float * src, float * dst, +static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, const bool circular, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2 * ne3); - pad_f32<<>>(src, dst, + pad_f32<<>>(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3, circular); } @@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); + GGML_TENSOR_UNARY_OP_LOCALS; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; @@ 
-96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; const int32_t circular = ((const int32_t *) (dst->op_params))[8]; - pad_f32_cuda(src0_d, dst_d, + const size_t s00 = nb00 / ggml_type_size(src0->type); + const size_t s01 = nb01 / ggml_type_size(src0->type); + const size_t s02 = nb02 / ggml_type_size(src0->type); + const size_t s03 = nb03 / ggml_type_size(src0->type); + + pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (bool) circular, stream); diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a..45a49a5dc2 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -43,10 +43,15 @@ static __device__ void rope_yarn( template static __global__ void rope_norm(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? 
freq_factors[i0/2] : 1.0f; @@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x, template static __global__ void rope_neox(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta); } -template -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -200,27 +229,24 @@ static __global__ void rope_multi( float theta_base 
= 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -238,37 +264,53 @@ static __global__ void rope_multi( dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template -static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections) { +template +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sec_w = sections.v[1] + sections.v[0]; - const int sector = (i0 / 2) % sect_dims; 
+ const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x]*powf(theta_scale, p); - } - else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + theta_base = pos[i2 + ne02] * powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -288,10 +330,15 @@ static __global__ void rope_vision( template static void rope_norm_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template static void rope_neox_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_neox<<>>( 
- x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } -template -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } -template -static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); // break down (head_dim, heads, 
seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); @@ -398,11 +487,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, 
ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4f0a1620fb..54f9986498 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -1991,6 +1986,25 @@ static bool 
ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return true; } +static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + + // TODO: add support for non-contiguous tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * dst = op; // indices + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_I32) { + return false; + } + + if (src0->ne[0] > (16*1024)) { + // reject tensors with huge rows for now + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu case GGML_OP_SUB: req->op = HTP_OP_SUB; break; + case GGML_OP_DIV: + req->op = HTP_OP_DIV; + break; default: GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); break; @@ -2316,6 +2353,17 @@ return n_bufs; } +static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_ARGSORT; + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SQR: + req->op = HTP_OP_SQR; + supported = true; + break; + + case GGML_OP_SQRT: + req->op = HTP_OP_SQRT; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2387,6 +2445,9 @@ } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { req->op = HTP_OP_GLU_SWIGLU_OAI; supported = true; + } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { + req->op = HTP_OP_GLU_GEGLU; + supported = true; + } break; @@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf } return n_bufs; } +static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_SUM_ROWS; + + size_t n_bufs = 0; + n_bufs +=
htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); req->op = HTP_OP_ROPE; @@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_ADD_ID: @@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_SUM_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; case GGML_OP_UNARY: if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { @@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg break; case GGML_OP_GLU: if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { + (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || + (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { ggml_hexagon_dispatch_op(sess, node, flags); } break; @@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_ARGSORT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: supp = ggml_hexagon_supported_binary(sess, op); break; @@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_unary(sess, op); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + supp = ggml_hexagon_supported_unary(sess, op); + break; + + case GGML_OP_SUM_ROWS: + supp = ggml_hexagon_supported_sum_rows(sess, op); + break; + case GGML_OP_SOFT_MAX: supp = ggml_hexagon_supported_softmax(sess, op); break; @@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_GLU: { const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) { + if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { supp = ggml_hexagon_supported_activations(sess, op); } break; @@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_ARGSORT: + supp = ggml_hexagon_supported_argsort(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index e8ef203045..2c23b60da3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include_directories( ${HEXAGON_SDK_ROOT}/incs 
${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR} @@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED matmul-ops.c binary-ops.c unary-ops.c + sum-rows-ops.c softmax-ops.c act-ops.c rope-ops.c @@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + argsort-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index c3daf5adb2..950d836ad3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, // gelu = x * sigmoid(1.702 * x) // current implementation hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, // silu = x * sigmoid(x) hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 
0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // Given src0_spad->size_per_thread, split it into two ping-pong buffers for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float))); + const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float))); + uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float))); + + // geglu tanh implementation + // geglu(x, g) = gelu(x) * g + // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))) + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = res * SQRT_2_OVER_PI + hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res) + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
+ hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res * 0.5f + hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, @@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } +static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + static int execute_op_activations_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { act_op_func = unary_gelu_f32; op_type = "gelu-f32"; break; + + case HTP_OP_GLU_GEGLU: + act_op_func = glu_geglu_f32; + op_type = "geglu-f32"; + break; default: FARF(ERROR, "Unsupported activations Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 0000000000..a4cee980be --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ?
(a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + // Scratchpad memory + uint8_t * spad = 
octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. + + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices + for (uint32_t j = 0; j < ne00; j++) { + indices_buf[j] = j; + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src0.ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * octx->n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. 
Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], + octx->src0.data, octx->dst.data); + + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + uint32_t n_jobs = MIN(total_rows, octx->n_threads); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index de22afe460..00dbcf8798 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -17,15 +17,37 @@ #include "htp-msg.h" #include "htp-ops.h" -typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems); +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif -static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; -static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa }; +// Context for binary operations +struct htp_binary_context { + struct htp_ops_context * octx; + struct fastdiv_values dim1_div; + struct fastdiv_values dim2_div; + struct fastdiv_values dim12_div; + + struct fastdiv_values src1_dim1_div; // ne11 + struct fastdiv_values src1_dim2_div; // ne12 + struct fastdiv_values src1_dim3_div; // ne13 + + uint32_t nrows_per_thread; + bool split_at_ne01; + bool split_at_ne02; + + // Precomputed values + uint32_t block_max; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t src1_fetch_rows; // 1 or block_max + uint32_t src1_dma_stride; // 0 or stride +}; #define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ - const struct htp_tensor * src2 = &octx->src2; \ struct htp_tensor * dst = &octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ @@ -38,266 +60,696 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f3 const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ const uint32_t nb01 = src0->nb[1]; \ const uint32_t nb02 = src0->nb[2]; \ const uint32_t nb03 = src0->nb[3]; \ \ - const uint32_t nb10 = src1->nb[0]; \ const uint32_t nb11 = src1->nb[1]; \ const uint32_t nb12 = src1->nb[2]; \ const uint32_t nb13 = src1->nb[3]; \ \ - const uint32_t nb0 = dst->nb[0]; \ const uint32_t nb1 = dst->nb[1]; \ const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t nb3 = dst->nb[3]; -static void binary_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - enum htp_op op) { - htp_binary_preamble; +static inline uint32_t 
calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, + uint32_t ne01, uint32_t ne02) { + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + uint32_t rows_left = end_row - ir; + uint32_t block_limit = rows_left; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows - - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; + if (bctx->split_at_ne01) { + block_limit = MIN(block_limit, ne01 - i01); + } + if (bctx->split_at_ne02) { + uint32_t rows_in_plane = (ne02 * ne01) - rem; + block_limit = MIN(block_limit, rows_in_plane); } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } - - hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op]; - - uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); - - const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); - uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); - - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - - const uint32_t ne02_ne01 = ne02 * ne01; - - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); - - const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3); - const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2); - const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1); - - const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; - - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1); - } - } - - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0); - } else { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); - } - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); - } - - src0_ptr += src0_row_size; - dst_ptr += dst_row_size; - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, - ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + return MIN(bctx->block_max, block_limit); } -static void 
binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - hvx_elemwise_f32_func func_HVX) { +// Macro for scalar op switch +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Aligned) +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Unaligned - generic loop used in element repeat) +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// 1. 
Scalar src1 (ne10 == 1) +static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; + // Preamble + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Main loop + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - const uint32_t ne02_ne01 = ne02 * ne01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - // src0 indices - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - // src1 indices - const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]); - assert(i11 >= 0 && i11 < ne11); + // src1 indices (broadcast/repeat) + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(i01, ne11, 
&bctx->src1_dim1_div); - float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1); - const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01); - const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); + uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint32_t s1_stride = (ne11 == 1) ? 0 : nb11; - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1); - } + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + float val = *(float *)src1_ptr; + src1_ptr += s1_stride; + COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; } + ir += current_block_size; } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + dma_queue_flush(q); } -static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; +// 2. 
Vector Same Shape (ne1x == ne0x) or Simple Broadcast +static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - switch (octx->op) { - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op); - break; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - case HTP_OP_ADD_ID: - binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32); - break; + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - default: - FARF(ERROR, "Unknown Binary Op %u", octx->op); - break; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint32_t i13 = (ne13 == 1) ? 0 : i03; + uint32_t i12 = (ne12 == 1) ? 0 : i02; + uint32_t i11 = (ne11 == 1) ? 
0 : i01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + + uint32_t p13 = (ne13 == 1) ? 0 : p03; + uint32_t p12 = (ne12 == 1) ? 0 : p02; + uint32_t p11 = (ne11 == 1) ? 0 : p01; + + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 3. 
Row Broadcast (ne11 == 1, ne12 == 1, single row src1) +static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + void * s1_ptr = (void *) src1_spad; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 4. 
Vector Complex (ne10 == ne00, complex broadcast) +static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Read src1 from DDR (unaligned) + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), 
next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 5. Element Repeat (ne10 != ne00) +static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Repeat src1 row + for (uint32_t c = 0; c < ne00; c += ne10) { + uint32_t len = MIN(ne10, ne00 - c); + // Use UUU for speed and simplicity + COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + } + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 
= fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 6. ADD_ID (src1 gathered via src2 indices) +static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + const uint32_t ne11 = src1->ne[1]; // for bounds check + + const uint32_t nb01 = src0->nb[1]; + const uint32_t nb02 = src0->nb[2]; + const uint32_t nb03 = src0->nb[3]; + const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 + + const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]); + + uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11; + 
uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); } static int execute_op_binary_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - worker_callback_t binary_op_func; - const char * op_type = NULL; - - switch (octx->op) { - case HTP_OP_MUL: - binary_op_func = binary_job_dispatcher_f32; - op_type = "mul-f32"; - break; - - case HTP_OP_ADD: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-f32"; - break; - - case HTP_OP_SUB: - binary_op_func = binary_job_dispatcher_f32; - op_type = "sub-f32"; - break; - - case HTP_OP_ADD_ID: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-id-f32"; - break; - - default: - FARF(ERROR, "Unsupported binary-Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } - - const int n_threads = octx->n_threads; + const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src1->nb[1]; - const size_t dst_row_size = dst->nb[1]; + // Use packed row sizes for VTCM allocation + const size_t src0_row_size = src0->ne[0] * sizeof(float); + const size_t src1_row_size = src1->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Align to VLEN + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + bool is_add_id = (octx->op == HTP_OP_ADD_ID); + bool is_scalar = !is_add_id && (src1->ne[0] == 1); - FARF(HIGH, - "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); + // Determine which kernel we will use to alloc memory and dispatch + bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + (src1->ne[1] == src0->ne[1] || 
src1->ne[1] == 1) && + (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && + (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); + bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + + size_t spad_row_total; + if (is_scalar) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (is_row_bcast) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (use_vector_same) { + spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); + } else if (is_add_id) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly + } else { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } + + size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case + if (is_row_bcast) { + size_t needed_static = src1_row_size_aligned; + if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL; + size_t avail = octx->ctx->vtcm_size - needed_static; + rows_per_buffer = avail / (n_threads * spad_row_total); + } + + if (rows_per_buffer < 1) { + FARF(ERROR, "binary-f32: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; + octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; + + if (is_scalar || use_complex || use_repeat || is_add_id) { + octx->src1_spad.size_per_thread = 0; + } else if (is_row_bcast) { + octx->src1_spad.size_per_thread = 0; + } else { + octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; + } + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + if (is_row_bcast) { + octx->src1_spad.size = src1_row_size_aligned; + } else { + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; } @@ -305,39 +757,71 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - - octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); - octx->src0_div3 = init_fastdiv_values(src0->ne[3]); - octx->src0_div2 = init_fastdiv_values(src0->ne[2]); - octx->src0_div1 = init_fastdiv_values(src0->ne[1]); - - octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); - octx->src1_div3 = init_fastdiv_values(src1->ne[3]); - octx->src1_div2 = init_fastdiv_values(src1->ne[2]); - octx->src1_div1 = init_fastdiv_values(src1->ne[1]); - - worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); + if ((octx->flags & 
HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + dma_queue * q = octx->ctx->dma[0]; + if (is_row_bcast) { + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + } + + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.block_max = rows_per_buffer; + bctx.src0_row_size_aligned = src0_row_size_aligned; + bctx.src1_row_size_aligned = src1_row_size_aligned; + bctx.dst_row_size_aligned = dst_row_size_aligned; + + bctx.dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + + bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + + bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); + + bctx.split_at_ne01 = (src0->ne[2] > 1) && + ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + + bctx.split_at_ne02 = (src0->ne[3] > 1) && + ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); + + // Precompute specific kernel parameters + if (use_vector_same) { + bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; + bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; + } + + worker_callback_t worker_func; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (use_vector_same) worker_func = binary_job_vector_same_shape; + else if (use_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; + + if (is_row_bcast) { + dma_queue_pop(q); + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + + return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - - switch (octx->src0.type) { - case HTP_TYPE_F32: - err = execute_op_binary_f32(octx); - break; - - default: - err = HTP_STATUS_NO_SUPPORT; - break; + if (octx->src0.type == HTP_TYPE_F32) { + return execute_op_binary_f32(octx); } - - return err; + return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index f49e8ee447..25403bb112 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -42,32 +42,36 @@ enum htp_data_type { HTP_TYPE_COUNT }; -// These values are manually translated over to HTP -// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!! 
+// Do not reorder first 4 (used as an index) enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT = 4, - HTP_OP_MUL_MAT_ID = 5, - HTP_OP_RMS_NORM = 6, - HTP_OP_UNARY_SILU = 7, - HTP_OP_UNARY_GELU = 8, - HTP_OP_GLU_SWIGLU = 9, - HTP_OP_GLU_SWIGLU_OAI = 10, - HTP_OP_SOFTMAX = 11, - HTP_OP_ADD_ID = 12, - HTP_OP_ROPE = 13, - HTP_OP_FLASH_ATTN_EXT = 14, - HTP_OP_SET_ROWS = 15, - HTP_OP_SCALE = 16, - HTP_OP_GET_ROWS = 17, - HTP_OP_CPY = 18, + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, INVALID }; -static inline size_t htp_type_block_size(uint32_t t) { +static inline size_t htp_t_block_size(uint32_t t) { switch (t) { case HTP_TYPE_F32: return 1; @@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) { return 0; } -static const char * htp_type_name(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return "fp32"; - case HTP_TYPE_F16: - return "fp16"; - case HTP_TYPE_Q4_0: - return "q4_0"; - case HTP_TYPE_Q8_0: - return "q8_0"; - case HTP_TYPE_MXFP4: - return "mxfp4"; - } - return 0; -} - // Internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 602a2775a4..f1ad24dbfa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -64,25 +64,12 @@ struct htp_ops_context { struct fastdiv_values broadcast_rv2; struct fastdiv_values broadcast_rv3; - struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 - struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 - struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 - struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01 - struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02 - struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03 - - struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00 - struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01 - struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02 - uint32_t flags; }; @@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); @@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int 
op_cpy(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 3449739a4f..2577cdd041 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -46,127 +46,76 @@ #define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -// ADD variants +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ -static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD); +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) + +// Dispatcher logic +#define HVX_BINARY_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t 
* restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ } -static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD); -} - -static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD); -} - -static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD); -} - -// SUB variants - -static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB); -} - -// MUL variants - -static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, 
hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL); -} - -// Dispatchers - -static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_add_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_add_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_uu(dst, src0, src1, num_elems); - } -} +HVX_BINARY_DISPATCHER(hvx_add_f32) +HVX_BINARY_DISPATCHER(hvx_sub_f32) +HVX_BINARY_DISPATCHER(hvx_mul_f32) // Mul-Mul Optimized - static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src0 % 128 == 0); @@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * } } +// +// Square +// + +#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t 
n) {
+    assert((unsigned long) dst % 128 == 0);
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) dst % 128 == 0);
+    hvx_sqr_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
+}
+
+static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    assert((unsigned long) src % 128 == 0);
+    hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
+    hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
+}
+
+static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
+    if (hex_is_aligned((void *) dst, 128)) {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_aa(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_au(dst, src, num_elems);
+        }
+    } else {
+        if (hex_is_aligned((void *) src, 128)) {
+            hvx_sqr_f32_ua(dst, src, num_elems);
+        } else {
+            hvx_sqr_f32_uu(dst, src, num_elems);
+        }
+    }
+}
+
 #undef HVX_OP_ADD
 #undef HVX_OP_SUB
 #undef HVX_OP_MUL
@@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
 #undef hvx_scalar_loop_body
 #undef HVX_OP_MIN_SCALAR
 #undef HVX_OP_CLAMP_SCALAR
+#undef DEFINE_HVX_BINARY_OP_VARIANTS
+#undef HVX_BINARY_DISPATCHER
 
 #endif // HVX_ARITH_H
diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h
index ffa6e18e64..12a1b7f128 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-base.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-base.h
@@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
     return x;
 }
 
+static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
+    int32_t __attribute__((aligned(128))) x;
+    hvx_vec_store_a(&x, 4, v);
+    return x;
+}
+
 static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
     // abs by clearing the fp16 sign bit
     HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);
diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h
index 6b617b7617..ae0dbed030 100644
--- a/ggml/src/ggml-hexagon/htp/hvx-copy.h
+++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h
@@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
         dst_type * restrict vdst = (dst_type *) dst; \
         src_type * restrict vsrc = (src_type *) src; \
                                                      \
-        const HVX_Vector zero = Q6_V_vsplat_R(0);    \
-                                                     \
         const uint32_t elem_size = sizeof(__fp16);   \
         const uint32_t epv = 128 / elem_size;        \
         const uint32_t nvec = n / epv;               \
diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h
new file mode 100644
index 0000000000..7dae012e0e
--- /dev/null
+++ b/ggml/src/ggml-hexagon/htp/hvx-div.h
@@ -0,0 +1,116 @@
+#ifndef HVX_DIV_H
+#define HVX_DIV_H
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include "hvx-base.h"
+#include "hex-utils.h"
+#include "hvx-inverse.h"
+#include "hvx-arith.h"
+
+#if __HVX_ARCH__ < 79
+#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
+#else
+#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
+#endif
+
+#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
+    do {                                                                 \
+        dst_type * restrict vdst   = (dst_type *) dst;                   \
+        src0_type * restrict vsrc0 = (src0_type *) src0;                 \
+        src1_type * restrict vsrc1 = (src1_type *) src1;                 \
+                                                                         \
+        const HVX_Vector nan_inf_mask = 
Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ + } \ + } while(0) + +// 3-letter suffix variants +static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); + else hvx_div_f32_aau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); + else hvx_div_f32_auu(dst, src0, src1, num_elems); + } + } else { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); + else hvx_div_f32_uau(dst, src0, src1, num_elems); 
+ } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); + else hvx_div_f32_uuu(dst, src0, src1, num_elems); + } + } +} + +#undef HVX_OP_MUL + +#endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 1b4aaff0c9..095193277e 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { } \ } while(0) +#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); @@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } +static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + #endif /* HVX_SIGMOID_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h index 28ee9f68d3..e31a1006d2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -12,11 +12,17 @@ #define RSQRT_ONE_HALF 0x3f000000 // 0.5 #define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { //Algorithm : // x2 = input*0.5 // y = * (long *) &input - // y = 0x5f3759df - (y>>2) + // y = 0x5f3759df - (y>>1) // y = y*(threehalfs - x2*y*y) HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); @@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } +// Compute sqrt(x) as x*inv_sqrt(x) +#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vdst[i] = sqrt_res; \ + } \ + if (nloe) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \ + } \ + } while(0) + +static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, 
hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_aa(dst, src, num_elems); + } else { + hvx_sqrt_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_ua(dst, src, num_elems); + } else { + hvx_sqrt_f32_uu(dst, src, num_elems); + } + } +} + #endif /* HVX_SQRT_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 7b79a5ea32..a518ad3733 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -12,6 +12,7 @@ #include "hvx-sigmoid.h" #include "hvx-sqrt.h" #include "hvx-arith.h" +#include "hvx-div.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e28a67a95d..92a1422896 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) { // otherwise we'll release it once we're done with the current Op. 
if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = false; + ctx->vtcm_needs_release = true; return 0; } @@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_sum_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: + case HTP_OP_DIV: if (n_bufs != 3) { FARF(ERROR, "Bad binary-req buffer list"); continue; @@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_unary_req(ctx, &req, bufs); break; + case HTP_OP_SQR: + case HTP_OP_SQRT: + if (n_bufs != 2) { + FARF(ERROR, "Bad 
unary-req buffer list"); + continue; + } + + proc_unary_req(ctx, &req, bufs); + break; + + case HTP_OP_SUM_ROWS: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_sum_rows_req(ctx, &req, bufs); + break; + case HTP_OP_UNARY_SILU: case HTP_OP_UNARY_GELU: if (n_bufs != 2) { @@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_GLU_SWIGLU: case HTP_OP_GLU_SWIGLU_OAI: case HTP_OP_SOFTMAX: + case HTP_OP_GLU_GEGLU: if ((n_bufs != 2) && (n_bufs != 3)) { FARF(ERROR, "Bad act-req buffer list"); continue; @@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index d251eeed33..c360abe8da 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -23,10 +23,30 @@ #define MM_SPAD_SRC1_NROWS 16 #define MM_SPAD_DST_NROWS 2 -struct htp_matmul_type { +struct htp_matmul_context { const char * type; - void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; }; // vdelta control to replicate first 4x fp32 values across lanes @@ -122,6 +142,7 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -133,15 +154,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 // Convert uint4 to int4 (i.e. 
x - 8) - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -156,6 +176,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -166,15 +187,14 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -196,46 +216,6 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 64 vals - HVX_Vector v1 = vptr[1]; // second 64 vals - HVX_Vector v2 = vptr[2]; // third 64 vals - HVX_Vector v3 = vptr[3]; // forth 64 vals - - HVX_Vector_x4 r = { v0, v1, v2, v3 }; - return r; -} - -static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { - const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; - - HVX_VectorPair v0 = vptr[0]; // first 64 vals - HVX_VectorPair v1 = vptr[1]; // second 64 vals - HVX_VectorPair v2 = vptr[2]; // third 64 vals - HVX_VectorPair v3 = vptr[3]; // forth 64 vals - - HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); - HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); - HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); - HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); - HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); - HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); - HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); - HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); - - HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); - HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); - HVX_Vector vh2 = 
Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); - HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); - - // vcombine does a shuffle, use vdeal to undo - - HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; - return r; -} - // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. @@ -300,26 +280,26 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, return hvx_vec_rmpy_x8_n(x, y, 1024); } -static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -372,36 +352,34 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t 
x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -468,13 +446,143 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const 
uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, 
vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; @@ -486,11 +594,11 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -543,36 +651,34 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; 
- const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -639,16 +745,143 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_mxfp4x4x2_q8x4x2(const int n, - float * restrict s, - const void * restrict vx, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q8_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums 
(sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = 
Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -660,11 +893,11 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -747,36 +980,34 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - 
assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -879,10 +1110,180 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + 
y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
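+        // (An e8m0 scale is just a biased 8-bit exponent, so its fp32 value is
+        //  2^(e - 127): place the byte into bits 30..23 with a zero mantissa,
+        //  i.e. roughly f32_bits = (uint32_t) e8m0 << 23.)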
+ // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... 
+ // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -912,14 +1313,12 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; - const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); - const HVX_Vector * restrict y = (const HVX_Vector *) vy; +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; @@ -953,10 +1352,86 @@ static void vec_dot_f16_f16_aa_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void 
vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 
row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -986,7 +1461,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -1083,14 +1558,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; uint64_t t1, t2; @@ -1136,13 +1613,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); - const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1); const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); // broadcast src0 into src1 - const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); - const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2); const uint32_t i1 = i11; const uint32_t i2 = i12; @@ -1155,7 +1632,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } } } @@ -1170,7 +1647,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx } // src1 tensor is already in VTCM spad -static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1195,7 +1672,7 @@ static void matmul_2d(struct 
htp_matmul_type * mt, struct htp_ops_context * octx // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; uint8_t * restrict src1_data = src1_spad->data; @@ -1219,11 +1696,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - #pragma unroll(2) - for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1247,20 +1734,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1311,7 +1798,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); // Prefetch next (n + spad_nrows) row const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ 
-1329,14 +1816,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1350,7 +1837,7 @@ struct mmid_row_mapping { }; // src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1423,11 +1910,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1453,25 +1939,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 
0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1531,7 +2016,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1549,13 +2034,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1754,12 +2239,14 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui hvx_copy_f16_ua(y_d, t_d, nb * 8); } -static void quantize_f32_q8x4x2(const struct htp_tensor * src, - uint8_t * restrict dst, - struct htp_spad * spad, - uint32_t nth, - uint32_t ith, - uint32_t nrows_per_thread) { +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1799,8 +2286,14 
@@ static void quantize_f32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1835,8 +2328,14 @@ static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict d } // TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1870,213 +2369,76 @@ static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict d ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); -} - -static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -// ** matmul/matvec callbacks for worker_pool - -static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct 
htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f32"; - mt.vec_dot = vec_dot_f16_f32_uu; - - matmul_4d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_uu; - - matmul_4d(&mt, octx, n, i); -} - -// ** matmul-id callbacks for worker_pool - -static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, 
unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -// ** main matmul entry point static inline bool htp_is_permuted(const struct htp_tensor * t) { return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; } +static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + return 0; + default: + return -1; + } +} + +static void htp_mminit_spad(struct htp_ops_context * octx, + size_t dst_row_size, + size_t src0_row_size_padded, + size_t src1_row_size, + uint32_t src1_nrows, + size_t src2_spad_size_per_thread) { + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + if (src2_spad_size_per_thread > 0) { + octx->src2_spad.size_per_thread = src2_spad_size_per_thread; + octx->src2_spad.size = octx->src2_spad.size_per_thread; + } + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; +} + int op_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; size_t src1_row_size = nb11; @@ -2085,181 +2447,95 @@ int op_matmul(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func; + worker_callback_t matmul_job_func = src1_nrows > 1 ? 
matmul_2d : matvec_2d; bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; - } + if (src0->type == HTP_TYPE_F16) { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + + src1_row_size = f16_src1_row_size; // row size post quantization octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; } else { - matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; } - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + src1_row_size = nb11; // original row size in DDR octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - 
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - case HTP_TYPE_MXFP4: - op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; - } + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_F16: - { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? 
htp_quantize_f32_f16 : htp_quantize_f16_f16; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_f16_f16; - } else { - matmul_job_func = htp_matvec_2d_f16_f16; - } - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - op_type = "f16-f32"; - matmul_job_func = htp_matmul_4d_f16_f32; - } else { - op_type = "f16-f16"; - matmul_job_func = htp_matmul_4d_f16_f16; - } - - src1_row_size = nb11; // original row size in DDR - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } - break; - - default: + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } // VTCM scratchpads for all tensors size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2268,40 +2544,32 @@ int op_matmul(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; 
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even - octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; if (need_quant) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; } -// ** main matmul-id entry point - int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + struct htp_tensor * restrict ids = &octx->src2; - const char * op_type; - - worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func; - const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -2310,6 +2578,13 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? 
matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + size_t src1_row_size; size_t src1_row_size_padded; @@ -2320,112 +2595,29 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; - } else { - 
matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - default: - return HTP_STATUS_NO_SUPPORT; + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2434,8 +2626,8 @@ int op_matmul_id(struct htp_ops_context * octx) { octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; if (src1_nrows > 1) { // initialize matrix_row_counts and map @@ -2447,8 +2639,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; 
++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = - *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); assert(i02 >= 0 && i02 < n_as); @@ -2460,16 +2651,14 @@ int op_matmul_id(struct htp_ops_context * octx) { // Setup worker pool callbacks if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { - // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul-id jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c new file mode 100644 index 0000000000..62e45da2b3 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -0,0 +1,115 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + + +#define sum_rows_preamble \ + struct htp_tensor *src0 = &octx->src0;\ + struct htp_tensor *dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + +static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { + sum_rows_preamble; + + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return HTP_STATUS_OK; + } + + int opt_path = 0; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src = (const uint8_t *) src0->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); + float * restrict dst_th = (float *) (data_dst + 
(src0_start_row * dst_row_size)); + + for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { + const float * restrict src_local = src_th + (ir * ne00); + + if (ir + 1 < src0_nrows_per_thread) { + hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + } + + if (1 == opt_path) { + dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); + } else { + dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); + } + } + + return HTP_STATUS_OK; +} + +static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { + sum_rows_thread_f32((struct htp_ops_context *) data, n, i); +} + +int op_sum_rows(struct htp_ops_context * octx) { + sum_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + + return HTP_STATUS_OK; +} + diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1a27cb6e63..ce879bf037 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src, } } +static void sqr_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + +static void sqrt_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + static void unary_job_f32_per_thread(const struct htp_tensor * src, struct htp_tensor * dst, uint8_t * spad, @@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_SCALE: scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SQR: + sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SQRT: + sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { 
unary_op_func = unary_job_dispatcher_f32; op_type = "scale-f32"; break; + case HTP_OP_SQR: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqr-f32"; + break; + case HTP_OP_SQRT: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqrt-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 95627d3866..87e1378684 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -264,15 +264,25 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vectorsrc[0])); - char base[256]; char name[256]; - const int64_t n = ggml_nelements(op); + int op_num = -1; - const char * op_str = "undefined"; switch (op->op) { - case GGML_OP_SCALE: op_str = "scale"; break; - case GGML_OP_FILL: op_str = "fill"; break; - case GGML_OP_CLAMP: op_str = "clamp"; break; - case GGML_OP_SQR: op_str = "sqr"; break; - case GGML_OP_SQRT: op_str = "sqrt"; break; - case GGML_OP_SIN: op_str = "sin"; break; - case GGML_OP_COS: op_str = "cos"; break; - case GGML_OP_LOG: op_str = "log"; break; - case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break; + case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break; + case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break; + case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break; + case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break; + case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break; + case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break; + case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break; + case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: op_str = "tanh"; break; - case GGML_UNARY_OP_RELU: op_str = "relu"; break; - case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; - case GGML_UNARY_OP_GELU: op_str = "gelu"; break; - case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; - case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; - case GGML_UNARY_OP_SILU: op_str = "silu"; break; - case GGML_UNARY_OP_ELU: op_str = "elu"; break; - case GGML_UNARY_OP_NEG: op_str = "neg"; break; - case GGML_UNARY_OP_ABS: op_str = "abs"; break; - case GGML_UNARY_OP_SGN: op_str = "sgn"; break; - case GGML_UNARY_OP_STEP: op_str = "step"; break; - case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; - case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; - case GGML_UNARY_OP_EXP: op_str = "exp"; break; - case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; - case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; + case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break; + case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break; + case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break; + case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break; + case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break; + case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break; + case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break; + case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break; + case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break; + case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break; + case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break; + case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break; + case GGML_UNARY_OP_HARDSWISH: op_num 
= OP_UNARY_NUM_HARDSWISH; break; + case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break; + case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; + case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; + case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); }; - const char * suffix = ""; - if (n % 4 == 0) { - suffix = "_4"; - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; + + snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); + ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } + res.c4 = is_c4; + res.cnt = is_cnt; + return res; } @@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? 
"_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } @@ -1392,34 +1415,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? 
"_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.c4 = is_c4; + res.cnt = is_rb; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; @@ -1428,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); - char base[256]; char name[256]; - snprintf(base, 256, "kernel_l2_norm_f32"); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); + + snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? 
"_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1442,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } + res.c4 = is_c4; res.smem = 32*sizeof(float); return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 84dcec3083..93d7f6a216 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c8e737d418..b4ca9c5dd6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; @@ -1007,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te } switch (op->op) { + case GGML_OP_SCALE: + case GGML_OP_FILL: + case GGML_OP_CLAMP: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous_rows(op->src[0]) && 
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1026,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -1054,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: @@ -1066,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: @@ -1083,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_COUNT_EQUAL: return has_simdgroup_reduction && op->src[0]->type == GGML_TYPE_I32 && @@ -1157,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; + case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7f73cb97bb..383e0d6e93 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,6 +80,9 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 +#define FC_UNARY 1200 +#define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -88,6 +91,37 @@ #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 +#define OP_UNARY_NUM_SCALE 10 +#define OP_UNARY_NUM_FILL 11 +#define OP_UNARY_NUM_CLAMP 12 +#define OP_UNARY_NUM_SQR 13 +#define OP_UNARY_NUM_SQRT 14 +#define OP_UNARY_NUM_SIN 15 +#define OP_UNARY_NUM_COS 16 +#define OP_UNARY_NUM_LOG 17 +#define OP_UNARY_NUM_LEAKY_RELU 18 + +#define OP_UNARY_NUM_TANH 100 +#define OP_UNARY_NUM_RELU 101 +#define OP_UNARY_NUM_SIGMOID 102 +#define OP_UNARY_NUM_GELU 103 +#define OP_UNARY_NUM_GELU_ERF 104 +#define OP_UNARY_NUM_GELU_QUICK 105 +#define OP_UNARY_NUM_SILU 106 +#define OP_UNARY_NUM_ELU 107 +#define OP_UNARY_NUM_NEG 108 +#define OP_UNARY_NUM_ABS 109 
+#define OP_UNARY_NUM_SGN 110 +#define OP_UNARY_NUM_STEP 111 +#define OP_UNARY_NUM_HARDSWISH 112 +#define OP_UNARY_NUM_HARDSIGMOID 113 +#define OP_UNARY_NUM_EXP 114 +#define OP_UNARY_NUM_SOFTPLUS 115 +#define OP_UNARY_NUM_EXPM1 116 + +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -123,6 +157,31 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float slope; + float scale; + float bias; + float val; + float min; + float max; +} ggml_metal_kargs_unary; + typedef struct { int32_t ne00; int32_t ne01; @@ -180,20 +239,6 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; -typedef struct { - float scale; - float bias; -} ggml_metal_kargs_scale; - -typedef struct { - float val; -} ggml_metal_kargs_fill; - -typedef struct { - float min; - float max; -} ggml_metal_kargs_clamp; - typedef struct { int64_t nk0; int64_t ne00; @@ -497,8 +542,21 @@ typedef struct { typedef struct { int32_t ne00; - int32_t ne00_4; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; } ggml_metal_kargs_l2_norm; @@ -880,10 +938,6 @@ typedef struct { int max_period; } ggml_metal_kargs_timestep_embedding; -typedef struct { - float slope; -} ggml_metal_kargs_leaky_relu; - typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0ed6c7805..c04e9fc7ff 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_acc(ctx, idx); } break; case GGML_OP_SCALE: - { - n_fuse = ggml_metal_op_scale(ctx, idx); - } break; case GGML_OP_FILL: - { - n_fuse = ggml_metal_op_fill(ctx, idx); - } break; case GGML_OP_CLAMP: - { - n_fuse = ggml_metal_op_clamp(ctx, idx); - } break; + case GGML_OP_LEAKY_RELU: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_top_k(ctx, idx); } break; - case GGML_OP_LEAKY_RELU: - { - n_fuse = ggml_metal_op_leaky_relu(ctx, idx); - } break; case GGML_OP_TRI: { n_fuse = ggml_metal_op_tri(ctx, idx); @@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); } break; + case GGML_OP_SET: + { + n_fuse = ggml_metal_op_set(ctx, idx); + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -707,7 +699,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -722,119 +714,6 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_scale(ggml_metal_op_t 
ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float scale; - float bias; - memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_scale args = { - /*.scale =*/ scale, - /*.bias =*/ bias, - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - const float val = ggml_get_op_params_f32(op, 0); - - ggml_metal_kargs_fill args = { - /*.val =*/ val - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float min; - float max; - memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_clamp args = { - /*.min =*/ min, - /*.max =*/ max, - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -846,19 +725,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - int64_t n = ggml_nelements(op); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); - if (n % 4 == 0) { - n /= 4; + ggml_metal_buffer_id 
bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_unary args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.slope =*/ 0.0, + /*.scale =*/ 0.0, + /*.bias =*/ 0.0, + /*.val =*/ 0.0, + /*.min =*/ 0.0, + /*.max =*/ 0.0, + }; + + if (op->op == GGML_OP_LEAKY_RELU) { + args.slope = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_SCALE) { + args.scale = ggml_get_op_params_f32(op, 0); + args.bias = ggml_get_op_params_f32(op, 1); + } + + if (op->op == GGML_OP_FILL) { + args.val = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_CLAMP) { + args.min = ggml_get_op_params_f32(op, 0); + args.max = ggml_get_op_params_f32(op, 1); } auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nth = MIN(args.ne00, nth_max); + + const int nk0 = (args.ne00 + nth - 1)/nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); + } return 1; } @@ -969,6 +908,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -990,21 +934,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ 
-1664,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type); + + GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0); + + int64_t nk0 = ne10; + if (ggml_is_quantized(op->src[1]->type)) { + nk0 = ne10/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne10/ggml_blck_size(op->type); + } + + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, + /*.nb00 =*/ nb10, + /*.nb01 =*/ nb11, + /*.nb02 =*/ nb12, + /*.nb03 =*/ nb13, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, + /*.nb0 =*/ ggml_element_size(op), + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + }; + + const int nw0 = nrptg == 1 ? 
(nk0 + nth - 1)/nth : 1; + + bid_dst.offs += offs; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -2895,8 +2972,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2990,18 +3065,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -3015,20 +3079,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + if (pipeline.cnt) { + const int n = pipeline.c4 ? 
ggml_nelements(op)/4 : ggml_nelements(op); ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { - int nth = 32; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { nth *= 2; } @@ -3049,39 +3121,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + float eps; memcpy(&eps, op->op_params, sizeof(float)); - int nth = 32; // SIMD width - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, }; auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); const size_t smem = pipeline.smem; - const int64_t nrows = ggml_nrows(op->src[0]); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); return 1; } @@ -4089,42 +4181,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float slope; - memcpy(&slope, op->op_params, sizeof(float)); - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_tri(ggml_metal_op_t ctx, int 
idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 3c64e4f600..f3e38c7aa9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); @@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); @@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 612a42a1ea..6c349aa0c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -895,60 +903,217 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( - constant ggml_metal_kargs_bin & args, +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; + +template +inline T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + 
a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +template T elu_approx(T x); + +template<> inline float elu_approx(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} + +template<> inline float4 elu_approx(float4 x) { + float4 res; + + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); + + return res; +} + +constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; +constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; + +template +kernel void kernel_unary_impl( + constant ggml_metal_kargs_unary & args, device const char * src0, - device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_unary_op +#define FC_CNT FC_unary_cnt - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + device const T0 * src0_ptr; + device T * dst_ptr; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + int i0; - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + if (FC_CNT) { + i0 = tgpig.x; + + src0_ptr = (device const T0 *) (src0); + dst_ptr = (device T *) (dst); + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int k0 = tgpig.x/args.ne01; + const int i01 = tgpig.x - k0*args.ne01; + + i0 = k0*ntg.x + tpitg.x; + + src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 ); } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + { + //threadgroup_barrier(mem_flags::mem_none); - float res = src0_ptr[i0]; - -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; + if (!FC_CNT) { + if (i0 >= args.ne0) { + return; + } } - dst_ptr[i0] = res; + const TC x = (TC) src0_ptr[i0]; + + if (FC_OP == OP_UNARY_NUM_SCALE) { + dst_ptr[i0] = (T) (args.scale * x + args.bias); + } + + if (FC_OP == OP_UNARY_NUM_FILL) { + dst_ptr[i0] = (T) args.val; + } + + if (FC_OP == OP_UNARY_NUM_CLAMP) { + dst_ptr[i0] = (T) clamp(x, args.min, args.max); + } + + if (FC_OP == OP_UNARY_NUM_SQR) { + dst_ptr[i0] = (T) (x * x); + } + + if (FC_OP == OP_UNARY_NUM_SQRT) { + dst_ptr[i0] = (T) sqrt(x); + } + + if (FC_OP == OP_UNARY_NUM_SIN) { + dst_ptr[i0] = (T) sin(x); + } + + if (FC_OP == OP_UNARY_NUM_COS) { + dst_ptr[i0] = (T) cos(x); + } + + if (FC_OP == OP_UNARY_NUM_LOG) { + dst_ptr[i0] = (T) log(x); + } + + if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); + } + + if (FC_OP == OP_UNARY_NUM_TANH) { + dst_ptr[i0] = (T) precise::tanh(x); + } + + if (FC_OP == OP_UNARY_NUM_RELU) { + dst_ptr[i0] = (T) fmax(0, x); + } + + if (FC_OP == OP_UNARY_NUM_SIGMOID) { + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); + } + + if (FC_OP == OP_UNARY_NUM_GELU) { + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); + } + + if (FC_OP == OP_UNARY_NUM_GELU_ERF) { + 
dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); + } + + if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); + } + + if (FC_OP == OP_UNARY_NUM_SILU) { + dst_ptr[i0] = (T) (x / (1 + exp(-x))); + } + + if (FC_OP == OP_UNARY_NUM_ELU) { + dst_ptr[i0] = (T) elu_approx(x); + } + + if (FC_OP == OP_UNARY_NUM_NEG) { + dst_ptr[i0] = (T) -x; + } + + if (FC_OP == OP_UNARY_NUM_ABS) { + dst_ptr[i0] = (T) fabs(x); + } + + if (FC_OP == OP_UNARY_NUM_SGN) { + dst_ptr[i0] = T(x > 0) - T(x < 0); + } + + if (FC_OP == OP_UNARY_NUM_STEP) { + dst_ptr[i0] = T(x > 0); + } + + if (FC_OP == OP_UNARY_NUM_HARDSWISH) { + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); + } + + if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); + } + + if (FC_OP == OP_UNARY_NUM_EXP) { + dst_ptr[i0] = (T) exp(x); + } + + if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); + } + + if (FC_OP == OP_UNARY_NUM_EXPM1) { + // TODO: precise implementation + dst_ptr[i0] = (T) (exp(x) - 1); + } } + +#undef FC_OP +#undef FC_CNT } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; +typedef decltype(kernel_unary_impl) kernel_unary_t; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl; -kernel void kernel_sub_fuse_1( +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; + +template +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -956,89 +1121,152 @@ kernel void kernel_sub_fuse_1( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.x; + const uint i1 = i0%args.ne10; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + 
i01*args.nb1 + args.offs; + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; + } + } else { + T0 res = src0_row[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + dst_row[i0] = res; } } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + if (i01 >= args.ne01) { + return; + } + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + if (FC_OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } + + if (FC_OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } + + if (FC_OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } + + if (FC_OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } + } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL 
(short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + T res = src0_ptr[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += src1_ptr[j][i10]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= src1_ptr[j][i10]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= src1_ptr[j][i10]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= src1_ptr[j][i10]; + } + } + + dst_ptr[i0] = res; + } } } + +#undef FC_OP +#undef FC_F +#undef FC_RB } -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, @@ -1057,7 +1285,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); @@ -1098,549 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -// assumption: src1 is a row -// broadcast src1 into src0 -template -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - float4 res = 
src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; - -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; - -template -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; - -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; - -template -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; - -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; - -template -kernel void kernel_div_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= 
src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - -kernel void kernel_scale_f32( - constant ggml_metal_kargs_scale & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_scale_f32_4( - constant ggml_metal_kargs_scale & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_fill_f32( - constant ggml_metal_kargs_fill & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_fill_f32_4( - constant ggml_metal_kargs_fill & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_clamp_f32( - constant ggml_metal_kargs_clamp & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_clamp_f32_4( - constant ggml_metal_kargs_clamp & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_relu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_relu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_sigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -kernel void kernel_tanh_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -constant float SQRT_2_INV = 0.70710678118654752440084436210484f; - -kernel void kernel_gelu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! 
- // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation -// ref: https://www.johndcook.com/blog/python_erf/ -constant float p_erf = 0.3275911f; -constant float a1_erf = 0.254829592f; -constant float a2_erf = -0.284496736f; -constant float a3_erf = 1.421413741f; -constant float a4_erf = -1.453152027f; -constant float a5_erf = 1.061405429f; - -template -T erf_approx(T x) { - T sign_x = sign(x); - x = fabs(x); - T t = 1.0f / (1.0f + p_erf * x); - T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); - return sign_x * y; -} - -kernel void kernel_gelu_erf_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_gelu_erf_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_silu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_elu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); -} - -kernel void kernel_elu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); - dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); - dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); - dst[tpig][3] = (x[3] > 0.0f) ? 
x[3] : (exp(x[3]) - 1.0f); -} - -kernel void kernel_sqr_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqr_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqrt_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sqrt_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sin_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_sin_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_cos_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_cos_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_log_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_log_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_neg_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_neg_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_abs_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_abs_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_sgn_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_sgn_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_step_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_step_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_hardswish_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardswish_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = 
src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_exp_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_exp_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_softplus_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_softplus_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_expm1_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - -kernel void kernel_expm1_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, @@ -1824,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1861,23 +1548,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? 
sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; template kernel void kernel_cumsum_blk( @@ -2758,9 +2455,6 @@ kernel void kernel_solve_tri_f32( const short K = FC_solve_tri_k; const short NP = PAD2(N, NW); - const int32_t ne02 = args.ne02; - const int32_t ne03 = args.ne03; - const int32_t i03 = tgpig.z; const int32_t i02 = tgpig.y; const int32_t i01 = tgpig.x*NSG + sgitg; @@ -3047,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -kernel void kernel_l2_norm_f32( +template +kernel void kernel_l2_norm_impl( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -3084,12 +2784,16 @@ kernel void kernel_l2_norm_f32( const float scale = 1.0f/sqrt(max(sumf, args.eps)); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; } } +typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t; + +template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; +template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, @@ -5191,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32( template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template 
[[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; -kernel void kernel_leaky_relu_f32( - constant ggml_metal_kargs_leaky_relu & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x > 0.0f ? x : x * args.slope; -} - -kernel void kernel_leaky_relu_f32_4( - constant ggml_metal_kargs_leaky_relu & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); -} - constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; @@ -6280,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t @@ -8868,7 +8554,9 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8991,8 +8679,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -9241,7 +8929,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9376,8 +9066,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -10058,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32( template kernel void kernel_memset( - constant ggml_metal_kargs_fill & args, + constant ggml_metal_kargs_memset & args, device T * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = args.val; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112..f389193691 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q4_1_f32 + mul_mv_q4_1_f32_flat + mul_mv_q4_k_f32 mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS gemv_moe_mxfp4_f32 mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q4_0_f32_l4_lm + mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_general_q8_0_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f03..ae3f79fd0d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ 
b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -525,6 +525,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -532,6 +533,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_1_f32; + cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -563,7 +567,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -886,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); @@ -1117,6 +1126,57 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = 
read_file("mul_mv_q4_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1342,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q4_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1358,6 +1450,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q6_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q6_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2887,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 { } }; +struct ggml_tensor_extra_cl_q4_1 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q4_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. 
+ // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_q = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -3363,7 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -3592,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() { + ggml_tensor_extra_cl_q4_1 * extra; + if (temp_tensor_extras_q4_1.empty()) { + extra = new ggml_tensor_extra_cl_q4_1(); + } else { + extra = temp_tensor_extras_q4_1.back(); + temp_tensor_extras_q4_1.pop_back(); + } + + temp_tensor_extras_q4_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -3648,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + temp_tensor_extras_q4_1.push_back(e); + } + temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -3673,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_in_use; std::vector temp_tensor_extras_q4_0; std::vector temp_tensor_extras_q4_0_in_use; + std::vector temp_tensor_extras_q4_1; + std::vector temp_tensor_extras_q4_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -4042,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. 
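+ // Converted layout per 32-element Q4_1 block: one half delta (d), one half min (m)
+ // and QK4_1/2 = 16 bytes of packed nibbles (q), each placed in its own subbuffer;
+ // kernel_convert_block_q4_1 below performs the AOS -> SOA split.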
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -4544,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size, data, 0, NULL, NULL)); CL_CHECK(clReleaseMemObject(data_device)); return; - } else if (tensor->type == GGML_TYPE_MXFP4) { + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, 
sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -8372,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -8885,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
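+ // Each work-group of nth0 = 128 work-items computes one 64x64 output tile, so the
+ // grid is CEIL_DIV(ne01, 64) x CEIL_DIV(ne11, 64) work-groups per batch, with the
+ // first dimension expressed in work-items rather than work-groups.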
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
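+ // Launch geometry is identical to the Q4_0 path above; the kernel additionally takes
+ // the mins buffer (extra0_q4_1->m) and dequantizes the A tile as d*q + m.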
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -8927,6 +9296,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q6_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
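+ // Same 64x64 tiling as the paths above; Q6_K super-blocks are dequantized while the
+ // A tile is staged into local memory (LOAD_VEC_A = 2, i.e. two values per work-item).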
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } @@ -9181,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -9262,7 +9739,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K: { + kernel = backend_ctx->kernel_mul_mv_q4_K_f32; + + if (backend_ctx->gpu_family == INTEL) { 
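+ // nth0 work-items per work-group produce ndst output rows; the global work size
+ // further below is derived from these values (nth0 presumably matches the kernel's
+ // subgroup width).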
+ nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + break; + } case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q @@ -9424,7 +9936,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q4_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 9fb434713d..2c244ce321 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,15 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_1 +// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. 
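+// For Q4_1 the split in fact produces three arrays: dst_q (QK4_1/2 = 16 bytes of
+// nibbles per block), dst_d (one half delta per block) and dst_m (one half min per
+// block); one work-item converts one block.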
+//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_1( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + + for (int i = 0; i < QK4_1/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_1( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/2; ++i) { + b->qs[i] = q[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl new file mode 100644 index 0000000000..4100e3080a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_0_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + 
loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl new file mode 100644 index 0000000000..d0d2f08361 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_1_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int 
stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + 
barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 0000000000..3602c92fef --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = 
(float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl new file mode 100644 index 0000000000..6fe828f20e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl @@ -0,0 +1,219 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y( + global const struct block_q4_1 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const 
ushort * qs = ((global const ushort *) qb_curr + 2 + il/2); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined 
(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl new file mode 100644 index 0000000000..d7c4645d67 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl @@ -0,0 +1,229 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y_flat( + global const uchar * x, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *) x + il/2); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = 
im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + // The number of scales/mins is the same as the number of blocks. + ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)); + // Each block contains QK4_1/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_dm; + global half * m = (global half *) src0_m + offset0_dm; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 0000000000..71ab989821 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION 
cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh 
= &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 651b875b63..00d54b83f8 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -836,16 +836,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_ceil(x); + }); } static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a03d26d7f2..0614d7e8f3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4591,9 +4591,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_CEIL: return true; case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 84d88e81d4..63f797f142 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -4,6 +4,7 @@ #include "ggml.h" #include "pre_wgsl.hpp" +#include #include #include @@ -18,9 +19,9 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - void * decisions; + std::string wgsl; + std::string 
variant; + std::shared_ptr<void> decisions; }; // Same hash combine function as in boost @@ -192,13 +193,13 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -270,11 +271,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } @@ -305,11 +306,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>(); + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -324,11 +325,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>(); + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -391,11 +392,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + 
result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } @@ -457,12 +458,81 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +struct ggml_webgpu_binary_shader_lib_context { + ggml_webgpu_binary_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_binary_shader_lib_context & context) { + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) context.key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (context.key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (context.key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (context.key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4ef50e365e..32e120266a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -186,11 +186,17 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); + if (bufs.host_buf) { + bufs.host_buf.Destroy(); + } + if (bufs.dev_buf) { + bufs.dev_buf.Destroy(); + } } free.clear(); } + + ~webgpu_buf_pool() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -252,13 +258,15 @@ struct webgpu_gpu_profile_buf_pool { } free.clear(); } + + ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } 
}; #endif struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; - void * context = nullptr; + std::shared_ptr context = nullptr; }; struct webgpu_command { @@ -319,6 +327,23 @@ struct webgpu_global_context_struct { wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + + ~webgpu_global_context_struct() { + if (this->get_tensor_staging_buf) { + this->get_tensor_staging_buf.Destroy(); + this->get_tensor_staging_buf = nullptr; + } +#ifdef GGML_WEBGPU_DEBUG + if (this->debug_host_buf) { + this->debug_host_buf.Destroy(); + this->debug_host_buf = nullptr; + } + if (this->debug_dev_buf) { + this->debug_dev_buf.Destroy(); + this->debug_dev_buf = nullptr; + } +#endif + } }; typedef std::shared_ptr webgpu_global_context; @@ -348,13 +373,12 @@ struct webgpu_context_struct { std::unordered_map set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::map> cpy_pipelines; // src_type, dst_type + + std::unordered_map + binary_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -745,7 +769,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } -// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -789,9 +812,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? 
total_gpu / total_cpu : 0.0) << "\n"; #endif -#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) - GGML_UNUSED(ctx); -#endif + delete ctx; + delete backend; } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -823,6 +845,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + + return flags; +} + static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -875,8 +919,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g ctx->pad_pipelines.emplace(pipeline_key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); const uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -920,7 +963,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -954,8 +997,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ctx->set_rows_pipelines.emplace(key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); std::optional error_bufs = std::nullopt; if (key.i64_idx) { @@ -1007,7 +1049,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, error_bufs); } @@ -1276,10 +1318,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ctx->flash_attn_pipelines.emplace(key, pipeline); } - ggml_webgpu_flash_attn_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1310,8 +1351,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s 
ctx->unary_pipelines.emplace(pipeline_key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1371,18 +1411,45 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); + + ggml_webgpu_binary_pipeline_key pipeline_key = { + .type = dst->type, + .op = dst->op, + .inplace = flags.inplace, + .overlap = flags.overlap, + }; + ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; + + webgpu_pipeline pipeline; + auto it = ctx->binary_pipelines.find(pipeline_key); + if (it != ctx->binary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->binary_pipelines.emplace(pipeline_key, pipeline); + } + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { - (uint32_t) ggml_nelements(dst), + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1399,24 +1466,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, (uint32_t) src1->ne[3], }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - if (!inplace) { + std::vector entries; + + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + + if (!flags.inplace && !flags.overlap) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); 
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1766,8 +1839,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr argsort_pipeline.context = processed.decisions; ctx->argsort_pipelines.emplace(order, argsort_pipeline); } - ggml_webgpu_argsort_shader_decisions argsort_decisions = - *static_cast(argsort_pipeline.context); + auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); webgpu_pipeline argsort_merge_pipeline; it = ctx->argsort_merge_pipelines.find(order); @@ -1784,13 +1856,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const uint32_t src_ne0 = (uint32_t) src->ne[0]; const uint32_t nrows = (uint32_t) ggml_nrows(src); - const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size); + const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size); const uint32_t block_size = - is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size; + is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size; uint32_t out_ne0 = src_ne0; if (is_top_k) { if (npr > 1) { - const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size; + const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size; out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size); } else { out_ne0 = block_size; @@ -2038,25 +2110,10 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return std::nullopt; #endif case GGML_OP_ADD: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); - } case GGML_OP_SUB: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); - } case GGML_OP_MUL: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); - } case GGML_OP_DIV: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2158,7 +2215,10 @@ static ggml_backend_i ggml_backend_webgpu_i = { static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); - ctx->buffer.Destroy(); + if (ctx != nullptr && ctx->buffer != nullptr) { + ctx->buffer.Destroy(); + delete ctx; + } } // Returns the "fake" base pointer. 
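Reviewer note on the WebGPU hunks above: the separate add/sub/mul/div pipeline maps are collapsed into a single preprocessed binary shader keyed by (type, op, inplace, overlap). The inplace case (src0 == dst) and the overlap case (src1 and dst share the same wgpu buffer with intersecting byte ranges) select shader variants that write the result into the matching source binding instead of a separate dst binding. Below is a minimal, self-contained C++ sketch of the half-open byte-range test behind ggml_webgpu_tensor_overlap; the struct and function names here are illustrative only, not part of the patch.

    // Illustration of the overlap predicate used by ggml_webgpu_detect_binary_overlap:
    // two views of the same buffer overlap exactly when their half-open byte ranges
    // [off, off + nbytes) intersect. Names below are hypothetical.
    #include <cstdint>
    #include <cstdio>

    struct byte_range { uint64_t off; uint64_t nbytes; };

    static bool ranges_overlap(byte_range a, byte_range b) {
        // a starts before b ends AND b starts before a ends
        return a.off < b.off + b.nbytes && b.off < a.off + a.nbytes;
    }

    int main() {
        byte_range dst  = { 0,    1024 }; // dst view at the start of the buffer
        byte_range src1 = { 512,  1024 }; // src1 view partially covering dst
        byte_range far  = { 4096,  256 }; // disjoint view

        // Overlap -> the OVERLAP shader variant is chosen and src1's binding is written.
        printf("dst/src1 overlap: %d\n", (int) ranges_overlap(dst, src1)); // 1
        // No overlap -> the regular variant with a dedicated dst binding is used.
        printf("dst/far  overlap: %d\n", (int) ranges_overlap(dst, far));  // 0
        return 0;
    }

The predicate is deliberately symmetric, so the same check works regardless of which tensor starts first in the buffer.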
@@ -2665,58 +2725,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } -static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); -} - -static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); -} - -static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); -} - -static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); -} - static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -2938,12 +2946,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { 
dev_desc.SetDeviceLostCallback( wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + if (reason == wgpu::DeviceLostReason::Destroyed) { + return; + } GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { @@ -3018,10 +3026,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_add_pipeline(webgpu_ctx); - ggml_webgpu_init_sub_pipeline(webgpu_ctx); - ggml_webgpu_init_mul_pipeline(webgpu_ctx); - ggml_webgpu_init_div_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); @@ -3381,10 +3385,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { return ctx->device_count; } -// TODO: Does this need to be thread safe? Is it only called once? -// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now - static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl deleted file mode 100644 index 1ce4d83fa8..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +++ /dev/null @@ -1,188 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "add_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "add_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] 
- }, - { - "SHADER_NAME": "div_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl new file mode 100644 index 0000000000..55dd66408a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -0,0 +1,107 @@ +enable f16; + +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#elif defined(OVERLAP) +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif + +fn op(a: DataType, b: DataType) -> DataType { +#ifdef OP_ADD + return a + b; +#elif defined(OP_SUB) + return a - b; +#elif defined(OP_MUL) + return a * b; +#elif defined(OP_DIV) + return a / b; +#endif +} + +fn update(dst_i: u32, src0_i: u32, src1_i: u32){ + let result = op(src0[src0_i], src1[src1_i]); + +#ifdef INPLACE + src0[dst_i] = result; +#elif defined(OVERLAP) + src1[dst_i] = result; +#else + dst[dst_i] = result; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn 
main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl deleted file mode 100644 index 4b254f468d..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +++ /dev/null @@ -1,45 +0,0 @@ -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 500cb6b72f..e2a6ff67be 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5749,7 +5749,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? 
ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3ddbc73d1c..9dab0df08a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -142,10 +142,13 @@ class Keys: EMBEDDING_SCALE = "{arch}.embedding_scale" TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" + FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval" ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale" ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" + SWIGLU_CLAMP_EXP = "{arch}.swiglu_clamp_exp" + SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp" DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" @@ -179,20 +182,20 @@ class Keys: TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" - FREQ_BASE = "{arch}.rope.freq_base" - FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" - SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" - SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor" - SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor" - SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast" - SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" + FREQ_BASE = "{arch}.rope.freq_base" + FREQ_BASE_SWA = "{arch}.rope.freq_base_swa" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor" + SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor" + SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast" + SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow" class Split: LLM_KV_SPLIT_NO = "split.no" @@ -382,6 +385,8 @@ class MODEL_ARCH(IntEnum): QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() + QWEN35 = auto() + QWEN35MOE = auto() PHI2 = auto() PHI3 = auto() PHIMOE = auto() @@ -462,6 +467,7 @@ class MODEL_ARCH(IntEnum): PANGU_EMBED = auto() MISTRAL3 = auto() MIMO2 = auto() + STEP35 = auto() LLAMA_EMBED = auto() MAINCODER = auto() KIMI_LINEAR = auto() @@ -554,13 +560,14 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_ALPHA = auto() # qwen3.5 SSM_BETA_ALPHA = auto() # qwen3next SSM_CONV1D_Q = auto() # Kimi Linear SSM_CONV1D_K = auto() # Kimi Linear SSM_CONV1D_V = auto() # Kimi Linear SSM_F_A = auto() # Kimi Linear SSM_F_B = auto() # Kimi Linear - SSM_BETA = auto() # Kimi Linear + SSM_BETA = auto() # Kimi Linear qwen3.5 SSM_G_A = auto() # Kimi Linear SSM_G_B = auto() # Kimi Linear TIME_MIX_W0 = auto() @@ -811,6 +818,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { 
MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", + MODEL_ARCH.QWEN35: "qwen35", + MODEL_ARCH.QWEN35MOE: "qwen35moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", @@ -892,6 +901,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MIMO2: "mimo2", + MODEL_ARCH.STEP35: "step35", MODEL_ARCH.LLAMA_EMBED: "llama-embed", MODEL_ARCH.MAINCODER: "maincoder", MODEL_ARCH.KIMI_LINEAR: "kimi-linear", @@ -981,13 +991,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_ALPHA: "blk.{bid}.ssm_alpha", # qwen3.5 MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear - MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear + MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear qwen3.5 MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", @@ -1814,6 +1825,61 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN35: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_ALPHA, + MODEL_TENSOR.SSM_OUT + ], + MODEL_ARCH.QWEN35MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3364,6 +3430,32 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.STEP35: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + 
MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + ], MODEL_ARCH.LLAMA_EMBED: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3674,6 +3766,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + KIMIK25 = "kimik25" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" @@ -3753,12 +3846,12 @@ KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS # RoPE -KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT -KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE -KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE -KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR -KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN -KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED # SSM KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f720aa2d54..a237537c8d 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -708,6 +708,9 @@ class GGUFWriter: def add_leading_dense_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_full_attention_interval(self, interval: int) -> None: + self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval) + def add_feed_forward_length(self, length: int | Sequence[int]) -> None: if isinstance(length, int): self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) @@ -824,6 +827,12 @@ class GGUFWriter: def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None: + self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values) + + def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None: + self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values) + def add_expert_group_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e16c06c2a3..43647904b4 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -228,6 +228,7 @@ class TensorNameMap: "transformer_encoder.{bid}.qkv", # neobert "layers.{bid}.attn.Wqkv", # modern-bert "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm + "model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5 ), # Attention query @@ -359,6 +360,8 @@ class TensorNameMap: MODEL_TENSOR.ATTN_GATE: ( "model.layers.{bid}.self_attn.gate_proj", # afmoe + "model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5 + "model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate ), # Feed-forward norm @@ -423,6 +426,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.router.gate", # afmoe 
"layers.{bid}.gate", # mistral-large "backbone.layers.{bid}.mixer.gate", # nemotron-h-moe + "model.layers.{bid}.moe.gate", # step3.5 ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -439,6 +443,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe "model.layers.{bid}.mlp.e_score_correction", # exaone-moe "model.layers.{bid}.block_sparse_moe.gate.e_score_correction", # kimi + "model.layers.{bid}.moe.router_bias", # step3.5 expert selection bias ), # Feed-forward up @@ -493,6 +498,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker + "model.layers.{bid}.moe.up_proj", # step3.5 ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -504,6 +510,7 @@ class TensorNameMap: "layers.{bid}.shared_experts.w3", # mistral-large "backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe "model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi + "model.layers.{bid}.share_expert.up_proj", # step3.5 ), MODEL_TENSOR.FFN_UP_CHEXP: ( @@ -543,6 +550,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker + "model.layers.{bid}.moe.gate_proj", # step3.5 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -552,6 +560,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan "layers.{bid}.shared_experts.w1", # mistral-large "model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi + "model.layers.{bid}.share_expert.gate_proj", # step3.5 ), MODEL_TENSOR.FFN_GATE_CHEXP: ( @@ -606,6 +615,7 @@ class TensorNameMap: "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker + "model.layers.{bid}.moe.down_proj", # step3.5 ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( @@ -617,6 +627,7 @@ class TensorNameMap: "layers.{bid}.shared_experts.w2", # mistral-large "backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe "model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi + "model.layers.{bid}.share_expert.down_proj", # step3.5 ), MODEL_TENSOR.FFN_DOWN_CHEXP: ( @@ -814,6 +825,10 @@ class TensorNameMap: "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_a", # qwen3.5 + ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next ), @@ -835,7 +850,8 @@ class TensorNameMap: "model.layers.{bid}.self_attn.f_b_proj", ), MODEL_TENSOR.SSM_BETA: ( - "model.layers.{bid}.self_attn.b_proj", + "model.layers.{bid}.linear_attn.in_proj_b", # qwen3.5 + "model.layers.{bid}.self_attn.b_proj", # Kimi Linear ), MODEL_TENSOR.SSM_G_A: ( "model.layers.{bid}.self_attn.g_a_proj", @@ -1287,6 +1303,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", + "mm_projector.proj.linear_{bid}", # Kimi-K2.5 "visual.merger.mlp.{bid}", # qwen2vl "merger.mlp.{bid}", ), @@ -1348,6 +1365,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5 ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1522,6 +1540,7 @@ class TensorNameMap: 
"multi_modal_projector.norm", "multi_modal_projector.layer_norm", "multi_modal_projector.pre_norm", + "mm_projector.pre_norm", # Kimi-K2.5 "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm "merger.ln_q", diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index f6c4cd14e7..48693ae3e3 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -23,7 +23,7 @@ numpy = ">=1.17" tqdm = ">=4.27" pyyaml = ">=5.1" requests = ">=2.25" -sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true } +sentencepiece = { version = ">=0.1.98,<0.3.0", optional = true } PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } [tool.poetry.dev-dependencies] diff --git a/include/llama.h b/include/llama.h index bf4e28a8be..305623127c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -482,7 +482,7 @@ extern "C" { enum llama_params_fit_status { LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path }; // fits mparams and cparams to free device memory (assumes system memory is unlimited) @@ -1150,9 +1150,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. 
diff --git a/pyproject.toml b/pyproject.toml index 3d71b055a8..422f53c7c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ [tool.poetry.dependencies] python = ">=3.9" numpy = "^1.25.0" -sentencepiece = ">=0.1.98,<=0.2.0" +sentencepiece = ">=0.1.98,<0.3.0" transformers = ">=4.35.2,<5.0.0" protobuf = ">=4.21.0,<5.0.0" gguf = { path = "./gguf-py" } diff --git a/requirements/requirements-convert_legacy_llama.txt b/requirements/requirements-convert_legacy_llama.txt index dbab3b9508..4898bf7ee2 100644 --- a/requirements/requirements-convert_legacy_llama.txt +++ b/requirements/requirements-convert_legacy_llama.txt @@ -1,5 +1,5 @@ numpy~=1.26.4 -sentencepiece~=0.2.0 +sentencepiece>=0.1.98,<0.3.0 transformers>=4.57.1,<5.0.0 diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index bd635f3b9d..067f5d466b 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -30,12 +30,18 @@ fi PR=$1 [[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } +url_origin=$(git config --get remote.upstream.url 2>/dev/null) || \ url_origin=$(git config --get remote.origin.url) || { - echo "error: no remote named 'origin' in this repository" + echo "error: no remote named 'upstream' or 'origin' in this repository" exit 1 } -org_repo=$(echo $url_origin | cut -d/ -f4-) +# Extract org/repo from either https or ssh format. +if [[ $url_origin =~ ^git@ ]]; then + org_repo=$(echo $url_origin | cut -d: -f2) +else + org_repo=$(echo $url_origin | cut -d/ -f4-) +fi org_repo=${org_repo%.git} echo "org/repo: $org_repo" diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 1ff6a9a40f..68db04dea9 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -2,6 +2,8 @@ import urllib.request +HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6" + vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", @@ -12,8 +14,8 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "vendor/cpp-httplib/httplib.h", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5238a5e934..fdda05d3ea 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -122,6 +122,8 @@ add_library(llama models/qwen3vl-moe.cpp models/qwen3moe.cpp models/qwen3next.cpp + models/qwen35.cpp + models/qwen35moe.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -135,6 +137,7 @@ add_library(llama models/stablelm.cpp models/starcoder.cpp models/starcoder2.cpp + models/step35-iswa.cpp models/t5-dec.cpp models/t5-enc.cpp models/wavtokenizer-dec.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a8bf1c9b80..a943d40dc4 100644 --- a/src/llama-arch.cpp 
+++ b/src/llama-arch.cpp @@ -37,6 +37,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -117,7 +119,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, @@ -162,6 +165,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -192,6 +197,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -220,21 +226,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, 
"%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -363,6 +369,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -965,7 +972,6 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_OUT, LLM_TENSOR_ATTN_QKV, LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, @@ -982,6 +988,63 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN35MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -2279,6 +2342,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { @@ -2424,6 +2516,7 @@ static const 
std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2643,6 +2736,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index f092f72834..4f7b51e70d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -41,6 +41,8 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -122,6 +124,7 @@ enum llm_arch { LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, @@ -166,6 +169,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -196,6 +201,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -401,13 +407,14 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 // Kimi Linear KDA (using SSM_ prefix for consistency) LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B - LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 LLM_TENSOR_SSM_G_A, // kimi: output gate projection A LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a6df893a31..6b43ca1926 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -677,7 +677,7 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } int64_t llama_context::output_resolve_row(int32_t i) const { @@ -715,7 +715,7 @@ float * llama_context::get_logits_ith(int32_t i) { output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } @@ -739,7 +739,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - return logits + j*model.vocab.n_tokens(); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -753,11 +753,11 @@ float * llama_context::get_logits_ith(int32_t i) { float * 
llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } llama_token * llama_context::get_sampled_tokens() const{ - return sampling.sampled; + return sampling.sampled.data; } float * llama_context::get_embeddings_ith(int32_t i) { @@ -766,7 +766,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } @@ -791,7 +791,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { } const uint32_t n_embd_out = model.hparams.n_embd_out(); - return embd + j*n_embd_out; + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); - if (sampling.sampled == nullptr) { + if (!sampling.sampled.has_data()) { return LLAMA_TOKEN_NULL; } try { const int64_t row = output_resolve_row(idx); - GGML_ASSERT(row < (int64_t) sampling.sampled_size); - return sampling.sampled[row]; + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); return LLAMA_TOKEN_NULL; @@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) { float * llama_context::get_sampled_probs_ith(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return nullptr; } @@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { return nullptr; } - return sampling.probs + row*model.vocab.n_tokens(); + return sampling.probs.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return nullptr; } @@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) { if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { return nullptr; } - return sampling.logits + row*model.vocab.n_tokens(); + return sampling.logits.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { try { const int64_t row = output_resolve_row(idx); - if (sampling.candidates != nullptr && + if (sampling.candidates.has_data() && (size_t) row < sampling.candidates_count.size() && sampling.candidates_count[row] > 0) { - return sampling.candidates + row*model.vocab.n_tokens(); + return sampling.candidates.data + row*model.vocab.n_tokens(); } } catch (const std::exception & err) { // fallback to full vocab list @@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { size_t llama_context::get_sampled_candidates_count(int32_t idx) { 
output_reorder(); - if (sampling.candidates == nullptr) { + if (!sampling.candidates.has_data()) { return 0; } @@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return model.vocab.n_tokens(); } @@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) { size_t llama_context::get_sampled_probs_count(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return 0; } @@ -1254,16 +1254,16 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1271,11 +1271,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1323,7 +1323,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1363,11 +1363,10 @@ static std::map build_seq_to_output_row(const llama_ubat static void copy_tensor_async_ints( const std::map & tensor_map, - llama_token * sampled, - size_t sampled_size, + const buffer_view & sampled, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (sampled == nullptr) { + if (!sampled.has_data()) { return; } @@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints( } const uint32_t row = it->second; - GGML_ASSERT(row < sampled_size); + GGML_ASSERT(row < sampled.size); GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); } } static void copy_tensor_async_floats( const std::map & tensor_map, - float * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, 
ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats( GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst + (size_t) row * stride; + float * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of logits/probabilities that were written for this row. @@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats( static void copy_tensor_async_candidates( const std::map & tensor_map, - llama_token * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates( GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst + (size_t) row * stride; + llama_token * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of candidates that were written. @@ -1671,22 +1670,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (embd && t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1694,13 +1693,13 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - float * embd_out = embd + n_outputs_prev*n_embd_out; + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; @@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto stride = n_vocab; // async copy the sampling data from the backend to the host - copy_tensor_async_ints(res->t_sampled, 
sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); @@ -1841,19 +1840,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; - - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates } if (output_ids.empty()) { @@ -1863,7 +1857,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits_size + embd_size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -1878,8 +1872,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { // TODO: not needed? buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1898,35 +1892,32 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = nullptr; - embd = nullptr; - size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = has_logits ? output_base : nullptr; - offset += logits_size * sizeof(float); + logits = has_logits ? buffer_view{output_base, logits.size} : buffer_view{nullptr, 0}; + offset += logits.size * sizeof(float); - embd = has_embd ? (float *) (base + offset) : nullptr; - offset += embd_size * sizeof(float); + embd = has_embd ? 
buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; + offset += embd.size * sizeof(float); - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); // The count vectors keep track of the actual number of logits/probs/candidates // copied from the backend for each output row. @@ -1939,7 +1930,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); } // set all ids as invalid (negative) @@ -1958,38 +1949,38 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } - if (sampling.logits && sampling.logits_size > 0) { + if (sampling.logits.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } } - if (sampling.probs && sampling.probs_size > 0) { + if (sampling.probs.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } } - if (sampling.candidates && sampling.candidates_size > 0) { + if (sampling.candidates.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } } - if (sampling.sampled && sampling.sampled_size > 0) { - std::swap(sampling.sampled[i0], sampling.sampled[i1]); + if 
(sampling.sampled.has_data()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); } if (!sampling.logits_count.empty()) { @@ -2013,7 +2004,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); @@ -2533,12 +2524,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); + const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); io.write(&logits_size, sizeof(logits_size)); if (logits_size) { - io.write(logits, logits_size * sizeof(float)); + io.write(logits.data, logits_size * sizeof(float)); } } @@ -2546,12 +2537,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); io.write(&embd_size, sizeof(embd_size)); if (embd_size) { - io.write(embd, embd_size * sizeof(float)); + io.write(embd.data, embd_size * sizeof(float)); } } @@ -2619,12 +2610,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t logits_size; io.read_to(&logits_size, sizeof(logits_size)); - if (this->logits_size < logits_size) { + if (this->logits.size < logits_size) { throw std::runtime_error("logits buffer too small"); } if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); + io.read_to(this->logits.data, logits_size * sizeof(float)); } } @@ -2635,12 +2626,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t embd_size; io.read_to(&embd_size, sizeof(embd_size)); - if (this->embd_size < embd_size) { + if (this->embd.size < embd_size) { throw std::runtime_error("embeddings buffer too small"); } if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); + io.read_to(this->embd.data, embd_size * sizeof(float)); } } diff --git a/src/llama-context.h b/src/llama-context.h index 8e71cdd1dc..d995117574 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -4,6 +4,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -269,29 +270,19 @@ private: std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + struct buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; + struct buffer_view embd = {nullptr, 0}; - // TODO: simplify struct sampling_info { std::map samplers; - float * logits = nullptr; - size_t logits_size = 0; - - llama_token * sampled = nullptr; - size_t sampled_size = 0; - - float * probs = 
nullptr; - size_t probs_size = 0; - - llama_token * candidates = nullptr; - size_t candidates_size = 0; + struct buffer_view logits = {nullptr, 0}; + struct buffer_view sampled = {nullptr, 0}; + struct buffer_view probs = {nullptr, 0}; + struct buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 165cbc0a7d..bba747d37b 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { @@ -1014,6 +1016,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -1316,6 +1338,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index a435043cfe..706eda8441 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -42,7 +42,6 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; @@ -206,6 +205,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. 
non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA diff --git a/src/llama-impl.h b/src/llama-impl.h index c3391e79f5..dfd9fee9f4 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -49,6 +49,16 @@ struct time_meas { int64_t & t_acc; }; +template +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string & search, const std::string & replace); // TODO: rename to llama_format ? diff --git a/src/llama-kv-cache-iswa.cpp b/src/llama-kv-cache-iswa.cpp index 3a34102a23..26e2cb4270 100644 --- a/src/llama-kv-cache-iswa.cpp +++ b/src/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c35cd6761b..cb702b2a59 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -974,6 +974,10 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. + if (model.arch == LLM_ARCH_STEP35) { + return false; + } return true; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 765e4de2e4..6b7da69e9d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -125,11 +125,13 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; @@ -521,7 +523,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -560,6 +563,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, 
hparams.n_layer, false); @@ -2400,8 +2405,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } } switch (hparams.n_layer) { @@ -2409,6 +2418,62 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MISTRAL3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2482,6 +2547,35 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_STEP35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + 
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -5953,9 +6047,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); // posnet @@ -6051,8 +6145,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); } break; case LLM_ARCH_BAILINGMOE: { @@ -7069,6 +7163,131 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_QWEN35MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), 
{ n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_MIMO2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7107,6 +7326,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MAINCODER: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7447,6 +7732,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); @@ -7678,7 +7965,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -7703,7 +7989,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_ubatch */ cparams.n_ubatch, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, @@ -7720,7 +8006,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_k */ params.type_k, /* attn_type_v 
*/ params.type_v, /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, @@ -8245,6 +8531,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN35: + { + llm = std::make_unique<llm_build_qwen35>(*this, params); + } break; + case LLM_ARCH_QWEN35MOE: + { + llm = std::make_unique<llm_build_qwen35moe>(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); } break; @@ -8257,6 +8551,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique<llm_build_step35_iswa>(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -8502,12 +8800,15 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/llama-model.h b/src/llama-model.h index 5b408bcea2..adc8ff6479 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -118,11 +118,13 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big @@ -321,6 +323,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 6d6bdfa090..62e137fb84 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -368,6 +368,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -1926,6 +1933,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 28c3a82b91..718238fb86 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -54,6 
+54,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, }; struct LLM_KV; diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 0f037d1a39..942844d071 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -41,8 +41,11 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_x, - ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, - (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); + ggml_view_3d(ctx0, conv_states_all, + d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps + n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) + (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] diff --git a/src/models/models.h b/src/models/models.h index 71c1fe8108..3c66d32531 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -476,6 +476,7 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; + struct llm_build_qwen3next : public llm_graph_context_mamba { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: @@ -534,6 +535,124 @@ private: const llama_model & model; }; +struct llm_build_qwen35 : public llm_graph_context_mamba { + llm_build_qwen35(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; +}; + +struct llm_build_qwen35moe : public llm_graph_context_mamba { + llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + 
ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; +}; + struct llm_build_qwen : public llm_graph_context { llm_build_qwen(const llama_model & model, const llm_graph_params & params); }; @@ -583,6 +702,10 @@ struct llm_build_starcoder : public llm_graph_context { llm_build_starcoder(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_step35_iswa : public llm_graph_context { + llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_t5_dec : public llm_graph_context { llm_build_t5_dec(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp new file mode 100644 index 0000000000..592c170457 --- /dev/null +++ b/src/models/qwen35.cpp @@ -0,0 +1,740 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); 
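+ // both branches leave the attention output in `cur`; the rest of the layer
+ // (output row selection, residual add, post-attention norm, FFN, second residual)
+ // is identical for full-attention and linear-attention layers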
+ } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // Dense FFN layer - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_ffn", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, 
"g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, 
ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, + 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul + + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat + + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul + + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + // replaced by precomputed attn_kq + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) + + // v_new = v_i - v_prime + 
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); + + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, new_state}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta 
= ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); + + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); + + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); + + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3.5, like Qwen3Next, uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ 
(n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply IMRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + 
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat Q/K to match V's head count + // V heads are in tiled order (from conversion), so simple tiled repeat works + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking and build_delta_net_autoregressive based on n_tokens + std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens 
* n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) { + // Qwen3.5 does not use MoE FFN + GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + return cur; +} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp new file mode 100644 index 0000000000..0db8f825c6 --- /dev/null +++ b/src/models/qwen35moe.cpp @@ -0,0 +1,774 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // 
Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // MOE FFN layer + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k 
&& beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: 
(chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, + 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul + + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat + + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul + + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + // replaced by precomputed attn_kq + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, 
q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); + + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, new_state}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, 
H_k, n_seqs); + + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); + + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); + + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); + + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + +std::pair llm_build_qwen35moe::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35moe::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, 
model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply IMRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = 
mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat Q/K to match V's head count + // V heads are in tiled order (from conversion), so simple tiled repeat works + if (num_k_heads != 
num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + std::pair attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, 
"shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + + return cur; +} diff --git a/src/models/step35-iswa.cpp b/src/models/step35-iswa.cpp new file mode 100644 index 0000000000..f8737815a6 --- /dev/null +++ b/src/models/step35-iswa.cpp @@ -0,0 +1,168 @@ +#include "models.h" + +llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // dump pre-attn RMSNorm input to pinpoint layer boundary issues + cb(cur, "attn_norm_in", il); + + // self-attention + { + cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + // Q/K per-head RMSNorm (Step35 q_norm / k_norm) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + // RoPE (partial rotary factors per layer) + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = is_swa ? 
hparams.n_rot : (hparams.n_rot / 2); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "attn_out", il); + // head-wise attention gate: sigmoid(g_proj(x)) in torch + if (model.layers[il].wqkv_gate) { + ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "attn_gate_sigmoid", il); + + // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens] + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "attn_gated", il); + } + + // output projection + cur = build_lora_mm(model.layers[il].wo, attn_out); + cb(cur, "attn_proj", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE routed experts + const bool norm_w = hparams.expert_weights_norm; + const float w_scale = hparams.expert_weights_scale; + const bool scale_w = w_scale != 0.0f; + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, + norm_w, scale_w, w_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // shared expert MLP (always added on MoE layers in Step35) + ggml_tensor * sh_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, nullptr, nullptr, + model.layers[il].ffn_gate_shexp, nullptr, nullptr, + model.layers[il].ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + 
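// Reference sketch for the head-wise attention gate applied inside the layer loop above
// (illustrative sketch only, not part of the diff; it assumes contiguous row-major ggml
// buffers, i.e. ne[0] varies fastest, and the helper name is hypothetical). build_attn
// returns [n_embd_head_v * n_head_l, n_tokens], the gate projection yields one scalar per
// head and token, [n_head_l, n_tokens], and reshaping the gate to [1, n_head_l, n_tokens]
// lets ggml_mul broadcast it across the head dimension:
//
//   attn[d, h, t] *= sigmoid(gate[h, t])   for every d in [0, n_embd_head_v)
//
#include <cmath>

static void apply_head_gate(float * attn, const float * gate,
                            int n_embd_head_v, int n_head, int n_tokens) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int h = 0; h < n_head; ++h) {
            // sigmoid of the per-head, per-token gate scalar
            const float g = 1.0f / (1.0f + std::exp(-gate[t * n_head + h]));
            // one head's slice of the attention output for this token
            float * head = attn + (t * n_head + h) * n_embd_head_v;
            for (int d = 0; d < n_embd_head_v; ++d) {
                head[d] *= g;
            }
        }
    }
}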
cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/unicode.cpp b/src/unicode.cpp index b47dcbe619..b88d953bd2 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -1,16 +1,10 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "unicode.h" #include "unicode-data.h" #include #include -#include #include #include -#include #include #include #include @@ -199,27 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } -static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert> conv; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return conv.from_bytes(s); -} - static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { std::vector bpe_encoded_words; for (const auto & word : bpe_words) { @@ -497,49 +470,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } -// use std::wregex to split the text -static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) { - std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); +template +static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { + using BidirIt = typename std::basic_string::const_iterator; +#ifdef _MSC_VER + // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830 + constexpr auto regex_flags = std::regex_constants::ECMAScript; +#else + constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs; +#endif + std::basic_regex expr(regex, regex_flags); std::vector bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size size_t start = 0; for (auto offset : offsets) { - std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); - std::wcregex_iterator end; + std::regex_iterator it(text.begin() + start, text.begin() + start + offset, expr); + std::regex_iterator end; int64_t start_idx = 0; while (it != end) { - std::wcmatch match = *it; - if (match.position() > start_idx) { - bpe_offsets.emplace_back(match.position() - start_idx); - } - bpe_offsets.emplace_back(match.length()); - start_idx = match.position() + match.length(); - ++it; - } - - if (start_idx < (int64_t) offset) { - bpe_offsets.emplace_back(offset - start_idx); - } - start += offset; - } - - return bpe_offsets; -} - -// use std::regex to split the text -static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { - std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); - std::vector bpe_offsets; // store the offset of each word - bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size - size_t start = 0; - for (auto offset : offsets) { - std::cregex_iterator it(text.data() + start, text.data() + start + offset, 
expr); - std::cregex_iterator end; - - int64_t start_idx = 0; - while (it != end) { - std::cmatch match = *it; + std::match_results match = *it; if (match.position() > start_idx) { bpe_offsets.emplace_back(match.position() - start_idx); } @@ -1051,10 +1001,10 @@ std::vector unicode_regex_split(const std::string & text, const std break; } } + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); if (use_collapsed) { // sanity-check that the original regex does not contain any non-ASCII characters - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); for (size_t i = 0; i < cpts_regex.size(); ++i) { if (cpts_regex[i] >= 128) { throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); @@ -1110,7 +1060,7 @@ std::vector unicode_regex_split(const std::string & text, const std bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); } else { // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end()); // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c9436c5995..350bffc315 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,7 +11,9 @@ function(llama_build source) add_executable(${TEST_TARGET} ${TEST_SOURCES}) target_link_libraries(${TEST_TARGET} PRIVATE common) - install(TARGETS ${TEST_TARGET} RUNTIME) + if (LLAMA_TESTS_INSTALL) + install(TARGETS ${TEST_TARGET} RUNTIME) + endif() endfunction() function(llama_test target) @@ -100,7 +102,9 @@ function(llama_build_and_test source) endif() add_executable(${TEST_TARGET} ${TEST_SOURCES}) - install(TARGETS ${TEST_TARGET} RUNTIME) + if (LLAMA_TESTS_INSTALL) + install(TARGETS ${TEST_TARGET} RUNTIME) + endif() target_link_libraries(${TEST_TARGET} PRIVATE common) add_test( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index fbe23037cc..8816f6963f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1943,7 +1943,11 @@ struct test_unary : public test_case { ggml_tensor * a; if (v & 1) { - auto ne = ne_a; ne[0] *= 3; + auto ne = ne_a; + ne[0] *= 3; + ne[1] *= 2; + ne[2] *= 5; + ne[3] *= 4; a = ggml_new_tensor(ctx, type, 4, ne.data()); if (grad_supported) { ggml_set_param(a); @@ -2782,9 +2786,10 @@ struct test_set : public test_case { const ggml_type type_dst; const std::array ne; const int dim; + const bool inplace; std::string vars() override { - return VARS_TO_STR4(type_src, type_dst, ne, dim); + return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace); } size_t op_size(ggml_tensor * t) override { @@ -2792,8 +2797,8 @@ struct test_set : public test_case { } test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, - std::array ne = {6, 5, 4, 3}, int dim = 1) - : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {} + std::array ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false) + : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2804,7 +2809,7 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { ne_dst[i] *= 2; } - ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); + ggml_tensor * dst = ggml_new_tensor(ctx, 
type_dst, 4, ne_dst.data()); ggml_set_param(dst); ggml_set_name(dst, "dst"); @@ -2812,9 +2817,16 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i]; } - ggml_tensor * out = ggml_set(ctx, dst, src, - // The backward pass requires setting a contiguous region: - src->nb[1], src->nb[2], src->nb[3], offset); + ggml_tensor * out; + if (inplace) { + out = ggml_set_inplace(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } else { + out = ggml_set(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } ggml_set_name(out, "out"); return out; @@ -2964,11 +2976,12 @@ struct test_bin_bcast : public test_case { const std::array ne; const std::array nr; int nf; // number of fused ops, nf == 1 -> single op (no fusion) + bool perm1; // permute src1? bool run_whole_graph() override { return nf > 1; } std::string vars() override { - return VARS_TO_STR4(type, ne, nr, nf); + return VARS_TO_STR5(type, ne, nr, nf, perm1); } size_t op_size(ggml_tensor * t) override { @@ -2978,8 +2991,9 @@ struct test_bin_bcast : public test_case { test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}, - int nf = 1) - : op(op), type(type), ne(ne), nr(nr), nf(nf) {} + int nf = 1, + bool perm1 = false) + : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {} ggml_tensor * build_graph(ggml_context * ctx) override { GGML_ASSERT(nf <= 16); @@ -2989,12 +3003,19 @@ struct test_bin_bcast : public test_case { ggml_tensor * b[16]; for (int i = 0; i < nf; ++i) { - b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + if (perm1) { + const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now + + b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]); + b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]); + } else { + b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + } ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str()); } // The backward pass supports broadcasting only for GGML_ADD: - const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1; + const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1; if (grad_supported) { ggml_set_param(a); ggml_set_param(b[0]); @@ -5894,33 +5915,36 @@ struct test_pad_ext : public test_case { const int rp2; const int lp3; const int rp3; - const bool v; + const int tfrm; // 0 - none, 1 - non-cont, 2 - perm const bool circular; std::string vars() override { - return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular); + return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular); } test_pad_ext(ggml_type type = GGML_TYPE_F32, std::array ne_a = {512, 512, 3, 1}, int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1, int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1, - bool v = false, bool circular = false) + int tfrm = 0, bool circular = false) : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), - v(v), circular(circular) {} + tfrm(tfrm), circular(circular) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); - if (v) { + if (tfrm == 1) { a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 
2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); + } else if (tfrm == 2) { + a = ggml_permute(ctx, a, 2, 1, 0, 3); + ggml_set_name(a, "permuted a"); } ggml_tensor * out = circular ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) - : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + : ggml_pad_ext (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); ggml_set_name(out, "out"); return out; @@ -7412,11 +7436,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true)); } for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true)); } // same-type copy @@ -7474,25 +7500,27 @@ static std::vector> make_test_cases_eval() { } } - auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { - test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr)); + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1)); } }; for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { - add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}); + for (bool perm1 : {false, true}) { + add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1); + } // test case for k_bin_bcast_unravel in CUDA backend add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1}); @@ -7879,20 +7907,27 @@ static 
std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_round (type)); test_cases.emplace_back(new test_trunc (type)); test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sqr (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sqrt (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_log (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_log (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sin (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_cos (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_clamp (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_floor (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_ceil (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_round (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_trunc (type, {1024, 1024, 1, 1})); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -8107,24 +8142,30 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_sum()); - test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mean()); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous + test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); - 
test_cases.emplace_back(new test_mean()); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1})); @@ -8198,10 +8239,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 })); - for (bool v : {false, true}) { + for (int tfrm : {0, 1, 2}) { for (bool circular : {false, true}) { - test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular)); - test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular)); + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular)); + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular)); } } @@ -8231,6 +8272,7 @@ static std::vector> make_test_cases_eval() { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { + if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue; test_cases.emplace_back(new test_flash_attn_ext( hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV)); // run fewer test cases permuted @@ -8519,7 +8561,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_rope(type, { 80, 32, 512, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 64, 8, 512, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, {128, 12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) - test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B) + test_cases.emplace_back(new test_rope(type, {128, 12, 512, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope 
(qwen3vl 2B) test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af32..02d71f224e 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(mtmd models/glm4v.cpp models/internvl.cpp models/kimivl.cpp + models/kimik25.cpp models/llama4.cpp models/llava.cpp models/minicpmv.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index ad232178bf..3bc93ead86 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -235,6 +235,7 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_KIMIK25, PROJECTOR_TYPE_UNKNOWN, }; @@ -268,6 +269,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_KIMIK25, "kimik25"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390..eeccb4cda0 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -10,6 +10,7 @@ #include "ggml-backend.h" #include "gguf.h" +#include #include #include #include @@ -672,8 +673,8 @@ ggml_tensor * clip_graph::build_rope_2d( { first = ggml_view_3d(ctx0, cur, n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), + cur->nb[1], + cur->nb[2], 0); first = ggml_rope_ext( ctx0, @@ -691,8 +692,8 @@ ggml_tensor * clip_graph::build_rope_2d( { second = ggml_view_3d(ctx0, cur, n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), + cur->nb[1], + cur->nb[2], n_dim/2 * ggml_element_size(cur)); second = ggml_rope_ext( ctx0, @@ -825,6 +826,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_KIMIK25: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_COGVLM: { builder = std::make_unique(ctx, img); @@ -1116,9 +1121,8 @@ struct clip_model_loader { case PROJECTOR_TYPE_LFM2: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); - // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 - hparams.set_limit_image_tokens(64, 1024); + // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json + hparams.set_limit_image_tokens(64, 256); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -1139,6 +1143,22 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 1024); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_KIMIK25: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + + int min_pixels = 0, max_pixels = 0; + get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false); + get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false); + if (min_pixels > 0 && max_pixels > 0) { + hparams.image_min_pixels = min_pixels; + hparams.image_max_pixels = max_pixels; + hparams.warmup_image_size = static_cast(std::sqrt(max_pixels)); + } else { + hparams.set_limit_image_tokens(2, 4096); + } + } break; case PROJECTOR_TYPE_GEMMA3: { // default value (used by all model sizes in 
gemma 3 family) @@ -1668,6 +1688,7 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); @@ -2807,6 +2828,119 @@ private: } }; +// ref: https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +// some of the logic is similar to llava_uhd, but with different hyperparameters and some logic is unique (e.g. grid layout) +struct lfm2_vl_image_processor { + // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json + static constexpr int min_tiles = 2; + static constexpr int max_tiles = 10; + static constexpr float max_pixels_tolerance = 2.0f; + static constexpr int tile_size = 512; + + static llava_uhd::slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) { + llava_uhd::slice_instructions inst; + const auto & params = ctx->model.hparams; + const int align_size = params.patch_size * params.n_merge; + + inst.interpolation_overview = img_tool::RESIZE_ALGO_BILINEAR; + inst.interpolation_refined = img_tool::RESIZE_ALGO_BILINEAR; + inst.overview_size = img_tool::calc_size_preserved_ratio(original_size, align_size, params.image_min_pixels, params.image_max_pixels); + + // tile if either dimension exceeds tile_size with tolerance + const bool needs_tiling = original_size.width > tile_size * max_pixels_tolerance || original_size.height > tile_size * max_pixels_tolerance; + + if (!needs_tiling) { + inst.refined_size = clip_image_size{0, 0}; + inst.grid_size = clip_image_size{0, 0}; + return inst; + } + + const clip_image_size grid = get_grid_layout(original_size.height, original_size.width); + + inst.grid_size = grid; + inst.refined_size = clip_image_size{tile_size * grid.width, tile_size * grid.height}; + + LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n", + __func__, + original_size.width, original_size.height, + inst.overview_size.width, inst.overview_size.height, + inst.refined_size.width, inst.refined_size.height, + grid.width, grid.height); + + for (int row = 0; row < grid.height; row++) { + for (int col = 0; col < grid.width; col++) { + llava_uhd::slice_coordinates slice; + slice.x = col * tile_size; + slice.y = row * tile_size; + slice.size = clip_image_size{tile_size, tile_size}; + inst.slices.push_back(slice); + LOG_DBG("%s: slice %d: x=%d, y=%d, size=%d x %d\n", + __func__, (int)inst.slices.size() - 1, + slice.x, slice.y, slice.size.width, slice.size.height); + } + } + + return inst; + } + +private: + static clip_image_size find_closest_aspect_ratio( + float aspect_ratio, + const std::vector & target_ratios, + int width, int height) { + float best_ratio_diff = std::numeric_limits::max(); + clip_image_size best_ratio = {1, 1}; + const float area = static_cast(width * height); + + for (const auto & ratio : target_ratios) { + const float target_aspect_ratio = static_cast(ratio.width) / ratio.height; + const float ratio_diff = std::abs(aspect_ratio - target_aspect_ratio); + if (ratio_diff < best_ratio_diff) { + best_ratio_diff = ratio_diff; + best_ratio = ratio; + } else if (ratio_diff == best_ratio_diff) { + const float target_area = static_cast(tile_size * tile_size * ratio.width * ratio.height); + if (area > 0.5f * target_area) { + best_ratio = ratio; + } + } + } + return 
best_ratio; + } + + static std::vector get_target_ratios() { + std::vector ratios; + for (int n = min_tiles; n <= max_tiles; n++) { + for (int w = 1; w <= n; w++) { + for (int h = 1; h <= n; h++) { + if (w * h >= min_tiles && w * h <= max_tiles) { + bool found = false; + for (const auto & r : ratios) { + if (r.width == w && r.height == h) { + found = true; + break; + } + } + if (!found) { + ratios.push_back({w, h}); + } + } + } + } + } + std::sort(ratios.begin(), ratios.end(), [](const clip_image_size & a, const clip_image_size & b) { + return a.width * a.height < b.width * b.height; + }); + return ratios; + } + + static clip_image_size get_grid_layout(int height, int width) { + const float aspect_ratio = static_cast(width) / height; + const auto ratios = get_target_ratios(); + return find_closest_aspect_ratio(aspect_ratio, ratios, width, height); + } +}; + // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -3021,6 +3155,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } break; case PROJECTOR_TYPE_LFM2: + { + auto const inst = lfm2_vl_image_processor::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; + } break; + case PROJECTOR_TYPE_KIMIVL: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); @@ -3032,8 +3180,24 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } break; + + case PROJECTOR_TYPE_KIMIK25: + { + GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); + const clip_image_size target_size = img_tool::calc_size_preserved_ratio( + original_size, + params.patch_size * params.n_merge, + params.image_min_pixels, + params.image_max_pixels); + const std::array pad_color = {0, 0, 0}; + + clip_image_u8 resized_img; + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -3247,6 +3411,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: { // dynamic size int out_patch_size = params.patch_size * ctx->model.hparams.n_merge; @@ -3588,6 +3753,7 @@ bool 
clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions @@ -3724,6 +3890,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); } + // Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set + if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) { + const int64_t n_embd = embeddings->ne[0]; + const int64_t n_tokens = embeddings->ne[1]; + std::vector emb_data(n_embd * n_tokens); + ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings)); + + LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n"); + LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens); + + // Print first few values of first token + LOG_INF("Token 0 (first 16 values): "); + for (int i = 0; i < std::min((int64_t)16, n_embd); i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + + // Print last few values of first token + if (n_embd > 16) { + LOG_INF("Token 0 (last 16 values): "); + for (int64_t i = n_embd - 16; i < n_embd; i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + } + + // Compute and print statistics + float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0]; + for (size_t i = 0; i < emb_data.size(); i++) { + sum += emb_data[i]; + sum_sq += emb_data[i] * emb_data[i]; + min_val = std::min(min_val, emb_data[i]); + max_val = std::max(max_val, emb_data[i]); + } + float mean = sum / emb_data.size(); + float variance = (sum_sq / emb_data.size()) - (mean * mean); + LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n", + mean, sqrtf(variance), min_val, max_val, sum); + LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n"); + } + return true; } @@ -3770,6 +3977,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp new file mode 100644 index 0000000000..cf9f27f63a --- /dev/null +++ b/tools/mtmd/models/kimik25.cpp @@ -0,0 +1,101 @@ +#include "models.h" +#include +#include + +// note: this is similar to clip_graph::resize_position_embeddings, major difference is having +// the w/h in ne[1] and ne[2] instead of assuming with sqrt. Could try storing the tensor in 2D instead +// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3). 
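// Shape walk-through of the interpolation path implemented just below (clarifying sketch;
// ggml ne[] is listed fastest-to-slowest, and the stored table is assumed to be an
// orig_w x orig_h grid of positions with C == n_embd):
//
//   model.position_embeddings        : ne = {C,  W0, H0, 1}
//   ggml_permute(.., 2, 1, 0, 3)     : ne = {H0, W0, C,  1}   // channels moved to ne[2]
//   ggml_interpolate(.., H, W, C, 1) : ne = {H,  W,  C,  1}   // resample the H0 x W0 grid to H x W
//   ggml_permute(.., 2, 1, 0, 3)     : ne = {C,  W,  H,  1}
//   ggml_cont_2d(.., C, W*H)         : ne = {C,  W*H}         // one n_embd row per patch position, width fastest
//
// This is also why the fast path for an unchanged grid size can skip the permutes and
// simply flatten the stored tensor with ggml_cont_2d.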
+ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) { + ggml_tensor * pos_embd = model.position_embeddings; + const int height = img.ny / patch_size; + const int width = img.nx / patch_size; + const uint32_t mode = interpolation_mode; + + GGML_ASSERT(pos_embd); + + const int64_t stored_c = pos_embd->ne[0]; // C = 1152 + const int64_t orig_w = pos_embd->ne[1]; // W = 64 + const int64_t orig_h = pos_embd->ne[2]; // H = 64 + + GGML_ASSERT(stored_c == n_embd); + + if (height == (int)orig_h && width == (int)orig_w) { + // No interpolation needed, just flatten to [C, H*W] + return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + } + + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); + pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode); + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); + pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + return pos_embd; +} + +ggml_cgraph * clip_graph_kimik25::build() { + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); + + // Kimi-K2.5 natively uses an interleaved 2D RoPE pattern, but + // Q / K are permuted during conversion to use the split format. + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + return cur; + }; + + ggml_tensor * inp = build_inp(); + + // I don't know why, but doing this in build_vit led to the ggml_add not occurring; + // doing it manually here does work.
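// Clarifying note on the add below: build_inp() should yield the patch embeddings as
// [n_embd, n_patches], and the interpolated table above was flattened to
// [n_embd, width*height] with width*height == n_patches for this image, so this is a
// plain element-wise add of the learned position embedding onto each patch embedding
// (no broadcasting involved).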
+ inp = ggml_add(ctx0, inp, learned_pos_embd); + + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, + add_pos); + + cb(cur, "vit_out", -1); + + { + // patch_merger + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection norm + int proj_inp_dim = cur->ne[0]; + int n_merged_patches = cur->ne[1]; + cur = ggml_view_2d(ctx0, cur, + n_embd, n_merged_patches * scale_factor * scale_factor, + ggml_row_size(cur->type, n_embd), 0); + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + cur = ggml_view_2d(ctx0, cur, + proj_inp_dim, n_merged_patches, + ggml_row_size(cur->type, proj_inp_dim), 0); + cb(cur, "proj_inp_normed", -1); + + // projection mlp + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + + cb(cur, "proj_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7b..c4c67ace62 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -109,3 +109,10 @@ struct clip_graph_mobilenetv5 : clip_graph { ggml_tensor * inp, const mobilenetv5_block & block); }; + +struct clip_graph_kimik25 : clip_graph { + clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode); +}; diff --git a/tools/mtmd/models/qwen3vl.cpp b/tools/mtmd/models/qwen3vl.cpp index 35a42cb84d..5ecb10fe43 100644 --- a/tools/mtmd/models/qwen3vl.cpp +++ b/tools/mtmd/models/qwen3vl.cpp @@ -182,7 +182,9 @@ ggml_cgraph * clip_graph_qwen3vl::build() { model.mm_1_w, model.mm_1_b, ffn_op_type::FFN_GELU, -1); - embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension + if (deepstack_features) { + embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); + } // concat along the feature dimension // build the graph ggml_build_forward_expand(gf, embeddings); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d037e834f3..b7636279cb 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -85,6 +85,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, MTMD_SLICE_TMPL_IDEFICS3, + MTMD_SLICE_TMPL_LFM2, }; const char * mtmd_default_marker() { @@ -307,9 +308,19 @@ struct mtmd_context { img_end = "<|im_end|>"; } else if (proj == PROJECTOR_TYPE_LFM2) { - img_beg = "<|image_start|>"; - img_end = "<|image_end|>"; - + // multi-tile: + // <|image_start|> + // <|img_row_1_col_1|> (tile) <|img_row_1_col_2|> (tile) ... 
+ // <|img_thumbnail|> (thumbnail) + // <|image_end|> + // single-tile: + // <|image_start|> (image) <|image_end|> + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + slice_tmpl = MTMD_SLICE_TMPL_LFM2; + sli_img_start_tmpl = "<|img_row_%d_col_%d|>"; + tok_ov_img_start = {lookup_token("<|img_thumbnail|>")}; + ov_img_first = false; } else if (proj == PROJECTOR_TYPE_GLM4V) { img_beg = "<|begin_of_image|>"; img_end = "<|end_of_image|>"; @@ -562,11 +573,13 @@ struct mtmd_tokenizer { } // handle llava-uhd style preprocessing + const bool has_tiling_grid = batch_f32.grid_x > 0 && batch_f32.grid_y > 0; if ( ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 + || (ctx->slice_tmpl == MTMD_SLICE_TMPL_LFM2 && has_tiling_grid) ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 0709e0bda0..c0f49279ee 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -119,27 +119,48 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp [[noreturn]] static void usage(const char * executable) { printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); - printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file] [--prune-layers] [--keep-split] [--override-kv]\n"); + printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file]\n"); + printf(" [--prune-layers] [--keep-split] [--override-kv]\n"); printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); - printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); - printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); - printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); - printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n"); - printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n"); - printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); - printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); - printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); - printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); - printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); - printf(" --tensor-type-file tensor_type.txt: list of tensors to quantize to specific ggml_type. example: --tensor-type-file tensor_type_list.txt\n"); - printf(" Advanced option to selectively quantize a long list of tensors. 
Format to be tensor_name=ggml_type, separated by spaces/newline.\n"); - printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n"); - printf(" Advanced option to remove all tensors from the given layers\n"); - printf(" --keep-split: will generate quantized model in the same shards as input\n"); + printf(" --allow-requantize\n"); + printf(" allow requantizing tensors that have already been quantized\n"); + printf(" WARNING: this can severely reduce quality compared to quantizing\n"); + printf(" from 16bit or 32bit!\n"); + printf(" --leave-output-tensor\n"); + printf(" leave output.weight un(re)quantized\n"); + printf(" increases model size but may also increase quality, especially when requantizing\n"); + printf(" --pure\n"); + printf(" disable k-quant mixtures and quantize all tensors to the same type\n"); + printf(" --imatrix file_name\n"); + printf(" use data in file_name as importance matrix for quant optimizations\n"); + printf(" --include-weights tensor_name\n"); + printf(" use importance matrix for this/these tensor(s)\n"); + printf(" --exclude-weights tensor_name\n"); + printf(" do not use importance matrix for this/these tensor(s)\n"); + printf(" --output-tensor-type ggml_type\n"); + printf(" use this ggml_type for the output.weight tensor\n"); + printf(" --token-embedding-type ggml_type\n"); + printf(" use this ggml_type for the token embeddings tensor\n"); + printf(" --tensor-type tensor_name=ggml_type\n"); + printf(" quantize this tensor to this ggml_type\n"); + printf(" this is an advanced option to selectively quantize tensors. may be specified multiple times.\n"); + printf(" example: --tensor-type attn_q=q8_0\n"); + printf(" --tensor-type-file tensor_types.txt\n"); + printf(" list of tensors to quantize to a specific ggml_type\n"); + printf(" this is an advanced option to selectively quantize a long list of tensors.\n"); + printf(" the file should use the same format as above, separated by spaces or newlines.\n"); + printf(" --prune-layers L0,L1,L2...\n"); + printf(" comma-separated list of layer numbers to prune from the model\n"); + printf(" WARNING: this is an advanced option, use with care.\n"); + printf(" --keep-split\n"); + printf(" generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); - printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); - printf("Note: --include-weights and --exclude-weights cannot be used together\n"); - printf("\nAllowed quantization types:\n"); + printf(" override model metadata by key in the quantized model. 
may be specified multiple times.\n"); + printf(" WARNING: this is an advanced option, use with care.\n\n"); + printf("note: --include-weights and --exclude-weights cannot be used together\n\n"); + printf("-----------------------------------------------------------------------------\n"); + printf(" allowed quantization types\n"); + printf("-----------------------------------------------------------------------------\n\n"); for (const auto & it : QUANT_OPTIONS) { if (it.name != "COPY") { printf(" %2d or ", it.ftype); diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 58b93c7468..521f79622d 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -1,12 +1,7 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml-rpc.h" #ifdef _WIN32 # define NOMINMAX # define DIRECTORY_SEPARATOR '\\' -# include # include # include # include @@ -15,23 +10,43 @@ # include # include #endif -#include #include #include #include -#include #include #include #include -namespace fs = std::filesystem; +#if defined(__linux__) +#include +#include +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif // NOTE: this is copied from common.cpp to avoid linking with libcommon // returns true if successful, false otherwise static bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); @@ -44,9 +59,16 @@ static bool fs_create_directory_with_parents(const std::string & path) { // process path from front to back, procedurally creating directories while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - const bool success = CreateDirectoryW(test, NULL); + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = CreateDirectoryW(subpath.c_str(), NULL); + if (!success) { const DWORD error = GetLastError(); @@ -60,8 +82,6 @@ static bool fs_create_directory_with_parents(const std::string & path) { return false; } } - - pos_slash += 1; } return true; @@ -115,13 +135,27 @@ static std::string fs_get_cache_directory() { #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); - } else { + } else if (std::getenv("HOME")) { cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } else { +#if defined(__linux__) + /* no $HOME is defined, fallback to getpwuid */ + struct passwd *pw = getpwuid(getuid()); + if ((!pw) || (!pw->pw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* 
defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ } #elif defined(__APPLE__) cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif diff --git a/tools/server/README.md b/tools/server/README.md index d132830171..0b56ca1e27 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp. * Speculative decoding * Easy-to-use web UI -For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) +For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) ## Usage diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index e3b06f4901..f4ff57b4c9 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 7f9c3c566b..ceafcac179 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -740,6 +740,11 @@ private: slots.clear(); + const bool can_spec = common_speculative_is_compat(ctx); + if (!can_spec) { + SRV_WRN("%s", "speculative decoding not supported by this context\n"); + } + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -752,7 +757,7 @@ private: slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - { + if (can_spec) { slot.spec = common_speculative_init(params_base.speculative, slot.ctx); if (slot.spec) { if (mctx) { @@ -2502,7 +2507,8 @@ private: slot.n_prompt_tokens_processed++; // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. 
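The hunk below replaces the previous fixed 64-token tail with one that scales with the batch size, capped at 512. A small numeric sketch of the new condition (the batch size and prompt length are assumed values, not taken from the patch):

    const int n_batch = 2048;                     // assumed batch size
    const int n_last  = std::min(n_batch, 512);   // -> 512
    // with an assumed 10000-token prompt (slot.task->n_tokens() == 10000), the prompt-filling
    // loop now breaks once slot.prompt.n_tokens() reaches 9488, so a checkpoint can be created
    // before the remaining 512 tokens are processed in a later batch.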
- if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { + const int n_last = std::min(n_batch, 512); + if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { break; } } @@ -3578,6 +3584,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_responses_to_chatcmpl(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, @@ -3594,6 +3602,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, @@ -3610,6 +3620,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 2d25db63b7..a137427c69 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -80,7 +80,6 @@ json task_params::to_json(bool only_metrics) const { {"speculative.type", common_speculative_type_to_str(speculative.type)}, {"speculative.ngram_size_n", speculative.ngram_size_n}, {"speculative.ngram_size_m", speculative.ngram_size_m}, - {"speculative.ngram_c_rate", speculative.ngram_check_rate}, {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, @@ -144,7 +143,6 @@ json task_params::to_json(bool only_metrics) const { {"speculative.type", common_speculative_type_to_str(speculative.type)}, {"speculative.ngram_size_n", speculative.ngram_size_n}, {"speculative.ngram_size_m", speculative.ngram_size_m}, - {"speculative.ngram_c_rate", speculative.ngram_check_rate}, {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, @@ -257,12 +255,10 @@ task_params server_task::params_from_json_cmpl( params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); - params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate); params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024); params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024); - params.speculative.ngram_check_rate = std::max(std::min(1, (int) params.speculative.ngram_check_rate), 1024); params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024); // Use OpenAI API logprobs only if n_probs wasn't provided diff --git 
a/tools/server/webui/docs/flows/settings-flow.md b/tools/server/webui/docs/flows/settings-flow.md index 578e01e6e1..474aef01b0 100644 --- a/tools/server/webui/docs/flows/settings-flow.md +++ b/tools/server/webui/docs/flows/settings-flow.md @@ -139,6 +139,6 @@ sequenceDiagram Note over settingsStore: UI-only (not synced): rect rgb(255, 240, 240) - Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningFormat + Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch end ``` diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 9705040a4d..3ab21f0cc7 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -14,11 +14,11 @@ --popover-foreground: oklch(0.145 0 0); --primary: oklch(0.205 0 0); --primary-foreground: oklch(0.985 0 0); - --secondary: oklch(0.97 0 0); + --secondary: oklch(0.95 0 0); --secondary-foreground: oklch(0.205 0 0); --muted: oklch(0.97 0 0); --muted-foreground: oklch(0.556 0 0); - --accent: oklch(0.97 0 0); + --accent: oklch(0.95 0 0); --accent-foreground: oklch(0.205 0 0); --destructive: oklch(0.577 0.245 27.325); --border: oklch(0.875 0 0); @@ -37,7 +37,7 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.975 0 0); + --code-background: oklch(0.985 0 0); --code-foreground: oklch(0.145 0 0); --layer-popover: 1000000; } @@ -51,7 +51,7 @@ --popover-foreground: oklch(0.985 0 0); --primary: oklch(0.922 0 0); --primary-foreground: oklch(0.205 0 0); - --secondary: oklch(0.269 0 0); + --secondary: oklch(0.29 0 0); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); @@ -116,12 +116,62 @@ --color-sidebar-ring: var(--sidebar-ring); } +:root { + --chat-form-area-height: 8rem; + --chat-form-area-offset: 2rem; + --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); +} + +@media (min-width: 640px) { + :root { + --chat-form-area-height: 24rem; + --chat-form-area-offset: 12rem; + } +} + @layer base { * { @apply border-border outline-ring/50; } + body { @apply bg-background text-foreground; + scrollbar-width: thin; + scrollbar-gutter: stable; + } + + /* Global scrollbar styling - visible only on hover */ + * { + scrollbar-width: thin; + scrollbar-color: transparent transparent; + transition: scrollbar-color 0.2s ease; + } + + *:hover { + scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent; + } + + *::-webkit-scrollbar { + width: 6px; + height: 6px; + } + + *::-webkit-scrollbar-track { + background: transparent; + } + + *::-webkit-scrollbar-thumb { + background: transparent; + border-radius: 3px; + transition: background 0.2s ease; + } + + *:hover::-webkit-scrollbar-thumb { + background: hsl(var(--muted-foreground) / 0.3); + } + + *::-webkit-scrollbar-thumb:hover { + background: hsl(var(--muted-foreground) / 0.5); } } diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte new file mode 100644 index 0000000000..4494ea880b --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte @@ -0,0 +1,48 @@ + + + + + + + + +

{tooltip}

+
+
diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte new file mode 100644 index 0000000000..bf6cd4fb28 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte @@ -0,0 +1,18 @@ + + + canCopy && copyToClipboard(text)} +/> diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte new file mode 100644 index 0000000000..1ae3d21774 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte @@ -0,0 +1,26 @@ + + + diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte new file mode 100644 index 0000000000..54ff0af1a0 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte @@ -0,0 +1,46 @@ + + +
+
+ +
+ + {#if showPreview} + + {/if} +
diff --git a/tools/server/webui/src/lib/components/app/actions/index.ts b/tools/server/webui/src/lib/components/app/actions/index.ts new file mode 100644 index 0000000000..43485c7b7e --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/index.ts @@ -0,0 +1,19 @@ +/** + * + * ACTIONS + * + * Small interactive components for user actions. + * + */ + +/** Styled icon button for action triggers with tooltip. */ +export { default as ActionIcon } from './ActionIcon.svelte'; + +/** Code block actions component (copy, preview). */ +export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte'; + +/** Copy-to-clipboard icon button with click handler. */ +export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte'; + +/** Remove/delete icon button with X icon. */ +export { default as ActionIconRemove } from './ActionIconRemove.svelte'; diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte new file mode 100644 index 0000000000..a2b28d2057 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte @@ -0,0 +1,44 @@ + + +{#if tooltipLabel} + + + + {#snippet icon()} + + {/snippet} + + {value} + + + +

{tooltipLabel}

+
+
+{:else} + + {#snippet icon()} + + {/snippet} + + {value} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte new file mode 100644 index 0000000000..c70af6f423 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte @@ -0,0 +1,27 @@ + + + diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte new file mode 100644 index 0000000000..a0d5e863c2 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte @@ -0,0 +1,39 @@ + + +{#each displayableModalities as modality, index (index)} + {@const IconComponent = MODALITY_ICONS[modality]} + {@const label = MODALITY_LABELS[modality]} + + + {#if IconComponent} + + {/if} + + {label} + +{/each} diff --git a/tools/server/webui/src/lib/components/app/badges/index.ts b/tools/server/webui/src/lib/components/app/badges/index.ts new file mode 100644 index 0000000000..860afe3084 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/index.ts @@ -0,0 +1,16 @@ +/** + * + * BADGES & INDICATORS + * + * Small visual indicators for status and metadata. + * + */ + +/** Badge displaying chat statistics (tokens, timing). */ +export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte'; + +/** Generic info badge with optional tooltip and click handler. */ +export { default as BadgeInfo } from './BadgeInfo.svelte'; + +/** Badge indicating model modality (vision, audio, tools). */ +export { default as BadgeModality } from './BadgeModality.svelte'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 27ab975cbd..95645295fb 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -27,11 +27,13 @@ interface Props { class?: string; disabled?: boolean; + initialMessage?: string; isLoading?: boolean; onFileRemove?: (fileId: string) => void; onFileUpload?: (files: File[]) => void; onSend?: (message: string, files?: ChatUploadedFile[]) => Promise; onStop?: () => void; + onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void; showHelperText?: boolean; uploadedFiles?: ChatUploadedFile[]; } @@ -39,11 +41,13 @@ let { class: className, disabled = false, + initialMessage = '', isLoading = false, onFileRemove, onFileUpload, onSend, onStop, + onSystemPromptAdd, showHelperText = true, uploadedFiles = $bindable([]) }: Props = $props(); @@ -53,15 +57,28 @@ let currentConfig = $derived(config()); let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); - let message = $state(''); + let message = $state(initialMessage); let pasteLongTextToFileLength = $derived.by(() => { const n = Number(currentConfig.pasteLongTextToFileLen); return Number.isNaN(n) ? 
Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n; }); let previousIsLoading = $state(isLoading); + let previousInitialMessage = $state(initialMessage); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); + // Sync message when initialMessage prop changes (e.g., after draft restoration) + $effect(() => { + if (initialMessage !== previousInitialMessage) { + message = initialMessage; + previousInitialMessage = initialMessage; + } + }); + + function handleSystemPromptClick() { + onSystemPromptAdd?.({ message, files: uploadedFiles }); + } + // Check if model is selected (in ROUTER mode) let conversationModel = $derived( chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) @@ -308,6 +325,7 @@ onFileUpload={handleFileUpload} onMicClick={handleMicClick} onStop={handleStop} + onSystemPromptClick={handleSystemPromptClick} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte index dd37268096..3545b4aebf 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte @@ -1,5 +1,6 @@ -
+
@@ -81,6 +88,16 @@
+ + {#if showRawOutputSwitch} +
+ Show raw output + onRawOutputToggle?.(checked)} + /> +
+ {/if}
{ @@ -238,7 +241,7 @@ {:else if message.role === 'assistant'} - {#if config().disableReasoningFormat} + {#if showRawOutput}
{messageContent || ''}
{:else} @@ -352,6 +355,9 @@ {onConfirmDelete} {onNavigateToSibling} {onShowDeleteDialogChange} + showRawOutputSwitch={currentConfig.showRawOutputSwitch} + rawOutputEnabled={showRawOutput} + onRawOutputToggle={(enabled) => (showRawOutput = enabled)} /> {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte index 24fe5926ba..d457e042fc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte @@ -3,6 +3,7 @@ import { BadgeChatStatistic } from '$lib/components/app'; import * as Tooltip from '$lib/components/ui/tooltip'; import { ChatMessageStatsView } from '$lib/enums'; + import { formatPerformanceTime } from '$lib/utils/formatters'; interface Props { predictedTokens?: number; @@ -57,8 +58,8 @@ ); let tokensPerSecond = $derived(hasGenerationStats ? (predictedTokens! / predictedMs!) * 1000 : 0); - let timeInSeconds = $derived( - predictedMs !== undefined ? (predictedMs / 1000).toFixed(2) : '0.00' + let formattedTime = $derived( + predictedMs !== undefined ? formatPerformanceTime(predictedMs) : '0s' ); let promptTokensPerSecond = $derived( @@ -67,15 +68,15 @@ : undefined ); - let promptTimeInSeconds = $derived( - promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined + let formattedPromptTime = $derived( + promptMs !== undefined ? formatPerformanceTime(promptMs) : undefined ); let hasPromptStats = $derived( promptTokens !== undefined && promptMs !== undefined && promptTokensPerSecond !== undefined && - promptTimeInSeconds !== undefined + formattedPromptTime !== undefined ); // In live mode, generation tab is disabled until we have generation stats @@ -142,7 +143,7 @@ - Send + Save diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 27439551a1..a5450e6af8 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -21,6 +21,7 @@ chatStore, errorDialog, isLoading, + isChatStreaming, isEditing, getAddFilesHandler } from '$lib/stores/chat.svelte'; @@ -71,6 +72,8 @@ let emptyFileNames = $state([]); + let initialMessage = $state(''); + let isEmpty = $derived( showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading() ); @@ -79,7 +82,7 @@ let isServerLoading = $derived(serverLoading()); let hasPropsError = $derived(!!serverError()); - let isCurrentConversationLoading = $derived(isLoading()); + let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming()); let isRouter = $derived(isRouterMode()); @@ -221,6 +224,14 @@ } } + async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) { + if (draft.message || draft.files.length > 0) { + chatStore.savePendingDraft(draft.message, draft.files); + } + + await chatStore.addSystemPrompt(); + } + function handleScroll() { if (disableAutoScroll || !chatScrollContainer) return; @@ -343,6 +354,12 @@ if (!disableAutoScroll) { setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); } + + const pendingDraft = chatStore.consumePendingDraft(); + if (pendingDraft) { + initialMessage = pendingDraft.message; + uploadedFiles = pendingDraft.files; + } }); 
$effect(() => { @@ -428,11 +445,13 @@
chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={false} bind:uploadedFiles /> @@ -486,11 +505,13 @@
chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={true} bind:uploadedFiles /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index 5a668aa300..967f19bbce 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -254,8 +254,13 @@ type: 'checkbox' }, { - key: 'disableReasoningFormat', - label: 'Show raw LLM output', + key: 'disableReasoningParsing', + label: 'Disable reasoning content parsing', + type: 'checkbox' + }, + { + key: 'showRawOutputSwitch', + label: 'Enable raw output toggle', type: 'checkbox' }, { diff --git a/tools/server/webui/src/lib/components/app/content/CollapsibleContentBlock.svelte b/tools/server/webui/src/lib/components/app/content/CollapsibleContentBlock.svelte new file mode 100644 index 0000000000..082738da57 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/CollapsibleContentBlock.svelte @@ -0,0 +1,97 @@ + + + { + open = value; + onToggle?.(); + }} + class={className} +> + + +
+ {#if Icon} + + {/if} + + {title} + + {#if subtitle} + {subtitle} + {/if} +
+ +
+ + + Toggle content +
+
+ + +
+ {@render children()} +
+
+
+
diff --git a/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte new file mode 100644 index 0000000000..ef6c7e064f --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte @@ -0,0 +1,1201 @@ + + +
+ {#each renderedBlocks as block (block.id)} +
+ + {@html block.html} +
+ {/each} + + {#if unstableBlockHtml} +
+ + {@html unstableBlockHtml} +
+ {/if} + + {#if incompleteCodeBlock} +
+
+ {incompleteCodeBlock.language || 'text'} + { + previewCode = code; + previewLanguage = lang; + previewDialogOpen = true; + }} + /> +
+
streamingAutoScroll.handleScroll()} + > +
{@html highlightCode(
+							incompleteCodeBlock.code,
+							incompleteCodeBlock.language || 'text'
+						)}
+
+
+ {/if} +
+ + + + diff --git a/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte new file mode 100644 index 0000000000..625fdc7b1b --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte @@ -0,0 +1,95 @@ + + +
+ +
{@html highlightedHtml}
+
+ + diff --git a/tools/server/webui/src/lib/components/app/content/index.ts b/tools/server/webui/src/lib/components/app/content/index.ts new file mode 100644 index 0000000000..bca1c9f4c2 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/content/index.ts @@ -0,0 +1,79 @@ +/** + * + * CONTENT RENDERING + * + * Components for rendering rich content: markdown, code, and previews. + * + */ + +/** + * **MarkdownContent** - Rich markdown renderer + * + * Renders markdown content with syntax highlighting, LaTeX math, + * tables, links, and code blocks. Optimized for streaming with + * incremental block-based rendering. + * + * **Features:** + * - GFM (GitHub Flavored Markdown): tables, task lists, strikethrough + * - LaTeX math via KaTeX (`$inline$` and `$$block$$`) + * - Syntax highlighting (highlight.js) with language detection + * - Code copy buttons with click feedback + * - External links open in new tab with security attrs + * - Image attachment resolution from message extras + * - Dark/light theme support (auto-switching) + * - Streaming-optimized incremental rendering + * - Code preview dialog for large blocks + * + * @example + * ```svelte + * + * ``` + */ +export { default as MarkdownContent } from './MarkdownContent.svelte'; + +/** + * **SyntaxHighlightedCode** - Code syntax highlighting + * + * Renders code with syntax highlighting using highlight.js. + * Supports theme switching and scrollable containers. + * + * **Features:** + * - Auto language detection with fallback + * - Dark/light theme auto-switching + * - Scrollable container with configurable max dimensions + * - Monospace font styling + * - Preserves whitespace and formatting + * + * @example + * ```svelte + * + * ``` + */ +export { default as SyntaxHighlightedCode } from './SyntaxHighlightedCode.svelte'; + +/** + * **CollapsibleContentBlock** - Expandable content card + * + * Reusable collapsible card with header, icon, and auto-scroll. + * Used for tool calls and reasoning blocks in chat messages. 
+ * + * **Features:** + * - Collapsible content with smooth animation + * - Custom icon and title display + * - Optional subtitle/status text + * - Auto-scroll during streaming (pauses on user scroll) + * - Configurable max height with overflow scroll + * + * @example + * ```svelte + * + * {reasoningContent} + * + * ``` + */ +export { default as CollapsibleContentBlock } from './CollapsibleContentBlock.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte index e2095e0876..21412f47e5 100644 --- a/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte +++ b/tools/server/webui/src/lib/components/app/misc/ConversationSelection.svelte @@ -17,9 +17,13 @@ let { conversations, messageCountMap = new Map(), mode, onCancel, onConfirm }: Props = $props(); let searchQuery = $state(''); - let selectedIds = $state.raw>(new SvelteSet(conversations.map((c) => c.id))); + let selectedIds = $state.raw>(getInitialSelectedIds()); let lastClickedId = $state(null); + function getInitialSelectedIds(): SvelteSet { + return new SvelteSet(conversations.map((c) => c.id)); + } + let filteredConversations = $derived( conversations.filter((conv) => { const name = conv.name || 'Untitled conversation'; @@ -92,7 +96,7 @@ } function handleCancel() { - selectedIds = new SvelteSet(conversations.map((c) => c.id)); + selectedIds = getInitialSelectedIds(); searchQuery = ''; lastClickedId = null; @@ -100,7 +104,7 @@ } export function reset() { - selectedIds = new SvelteSet(conversations.map((c) => c.id)); + selectedIds = getInitialSelectedIds(); searchQuery = ''; lastClickedId = null; } diff --git a/tools/server/webui/src/lib/components/app/misc/DropdownMenuSearchable.svelte b/tools/server/webui/src/lib/components/app/misc/DropdownMenuSearchable.svelte new file mode 100644 index 0000000000..21ba04cf66 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/DropdownMenuSearchable.svelte @@ -0,0 +1,88 @@ + + + + { + e.preventDefault(); + e.stopPropagation(); + }} + > + {@render trigger()} + + + +
+ +
+ +
+ {@render children()} + + {#if isEmpty} +
{emptyMessage}
+ {/if} +
+ + {#if footer} + + + {@render footer()} + {/if} +
+
diff --git a/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte b/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte new file mode 100644 index 0000000000..e302f83e11 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/HorizontalScrollCarousel.svelte @@ -0,0 +1,93 @@ + + +
+ + +
+ {@render children?.()} +
+ + +
diff --git a/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte b/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte index 5b7522fe1b..da55abda02 100644 --- a/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte +++ b/tools/server/webui/src/lib/components/app/misc/KeyboardShortcutInfo.svelte @@ -11,7 +11,9 @@ let baseClasses = 'px-1 pointer-events-none inline-flex select-none items-center gap-0.5 font-sans text-md font-medium opacity-0 transition-opacity -my-1'; - let variantClasses = variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground'; + let variantClasses = $derived( + variant === 'destructive' ? 'text-destructive' : 'text-muted-foreground' + ); diff --git a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte index cb3ae17a63..0084499f85 100644 --- a/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/misc/MarkdownContent.svelte @@ -486,6 +486,8 @@ text-decoration: underline; text-underline-offset: 2px; transition: color 0.2s ease; + overflow-wrap: anywhere; + word-break: break-all; } div :global(a:hover) { diff --git a/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte b/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte new file mode 100644 index 0000000000..9a8731fc78 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/TruncatedText.svelte @@ -0,0 +1,48 @@ + + +{#if isTruncated} + + + + {text} + + + + +

{text}

+
+
+{:else} + + {text} + +{/if} diff --git a/tools/server/webui/src/lib/components/app/misc/index.ts b/tools/server/webui/src/lib/components/app/misc/index.ts new file mode 100644 index 0000000000..02bd70b24f --- /dev/null +++ b/tools/server/webui/src/lib/components/app/misc/index.ts @@ -0,0 +1,45 @@ +/** + * + * MISC + * + * Miscellaneous utility components. + * + */ + +/** + * **ConversationSelection** - Multi-select conversation picker + * + * List of conversations with checkboxes for multi-selection. + * Used in import/export dialogs for selecting conversations. + * + * **Features:** + * - Search/filter conversations by name + * - Select all / deselect all controls + * - Shift-click for range selection + * - Message count display per conversation + * - Mode-specific UI (export vs import) + */ +export { default as ConversationSelection } from './ConversationSelection.svelte'; + +/** + * Horizontal scrollable carousel with navigation arrows. + * Used for displaying items in a horizontally scrollable container + * with left/right navigation buttons that appear on hover. + */ +export { default as HorizontalScrollCarousel } from './HorizontalScrollCarousel.svelte'; + +/** + * **TruncatedText** - Text with ellipsis and tooltip + * + * Displays text with automatic truncation and full content in tooltip. + * Useful for long names or paths in constrained spaces. + */ +export { default as TruncatedText } from './TruncatedText.svelte'; + +/** + * **KeyboardShortcutInfo** - Keyboard shortcut hint display + * + * Displays keyboard shortcut hints (e.g., "⌘ + Enter"). + * Supports special keys like shift, cmd, and custom text. + */ +export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte'; diff --git a/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte new file mode 100644 index 0000000000..83d856d10e --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuActions.svelte @@ -0,0 +1,86 @@ + + + + e.stopPropagation()} + > + {#if triggerTooltip} + + + {@render iconComponent(triggerIcon, 'h-3 w-3')} + {triggerTooltip} + + +

{triggerTooltip}

+
+
+ {:else} + {@render iconComponent(triggerIcon, 'h-3 w-3')} + {/if} +
+ + + {#each actions as action, index (action.label)} + {#if action.separator && index > 0} + + {/if} + + +
+ {@render iconComponent( + action.icon, + `h-4 w-4 ${action.variant === 'destructive' ? 'text-destructive' : ''}` + )} + {action.label} +
+ + {#if action.shortcut} + + {/if} +
+ {/each} +
+
+ +{#snippet iconComponent(IconComponent: Component, className: string)} + +{/snippet} diff --git a/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte new file mode 100644 index 0000000000..3bd68d3bd6 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/DropdownMenuSearchable.svelte @@ -0,0 +1,50 @@ + + +
+ +
+ +
+ {@render children()} + + {#if isEmpty} +
{emptyMessage}
+ {/if} +
+ +{#if footer} + + + {@render footer()} +{/if} diff --git a/tools/server/webui/src/lib/components/app/navigation/index.ts b/tools/server/webui/src/lib/components/app/navigation/index.ts new file mode 100644 index 0000000000..051491b866 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/navigation/index.ts @@ -0,0 +1,65 @@ +/** + * + * NAVIGATION & MENUS + * + * Components for dropdown menus and action selection. + * + */ + +/** + * **DropdownMenuSearchable** - Searchable content for dropdown menus + * + * Renders a search input with filtered content area, empty state, and optional footer. + * Designed to be injected into any dropdown container (DropdownMenu.Content, + * DropdownMenu.SubContent, etc.) without providing its own Root. + * + * **Features:** + * - Search/filter input + * - Keyboard navigation support + * - Custom content and footer via snippets + * - Empty state message + * + * @example + * ```svelte + * + * ... + * + * + * {#each items as item}{/each} + * + * + * + * ``` + */ +export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svelte'; + +/** + * **DropdownMenuActions** - Multi-action dropdown menu + * + * Dropdown menu for multiple action options with icons and shortcuts. + * Supports destructive variants and keyboard shortcut hints. + * + * **Features:** + * - Configurable trigger icon with tooltip + * - Action items with icons and labels + * - Destructive variant styling + * - Keyboard shortcut display + * - Separator support between groups + * + * @example + * ```svelte + * + * ``` + */ +export { default as DropdownMenuActions } from './DropdownMenuActions.svelte'; diff --git a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte index fa4c2842cc..520e5bf56f 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerErrorSplash.svelte @@ -8,6 +8,7 @@ import { serverStore, serverLoading } from '$lib/stores/server.svelte'; import { config, settingsStore } from '$lib/stores/settings.svelte'; import { fade, fly, scale } from 'svelte/transition'; + import { KeyboardKey } from '$lib/enums/keyboard'; interface Props { class?: string; @@ -117,7 +118,7 @@ } function handleApiKeyKeydown(event: KeyboardEvent) { - if (event.key === 'Enter') { + if (event.key === KeyboardKey.ENTER) { handleSaveApiKey(); } } diff --git a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte index d9f6d4a32a..86a962de12 100644 --- a/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte +++ b/tools/server/webui/src/lib/components/app/server/ServerStatus.svelte @@ -48,7 +48,7 @@ {model || 'Unknown Model'} - {#if serverData.default_generation_settings.n_ctx} + {#if serverData?.default_generation_settings?.n_ctx} ctx: {serverData.default_generation_settings.n_ctx.toLocaleString()} diff --git a/tools/server/webui/src/lib/components/app/server/index.ts b/tools/server/webui/src/lib/components/app/server/index.ts new file mode 100644 index 0000000000..39ac5b482d --- /dev/null +++ b/tools/server/webui/src/lib/components/app/server/index.ts @@ -0,0 +1,80 @@ +/** + * + * SERVER + * + * Components for displaying server connection state and handling + * connection errors. Integrates with serverStore for state management. 
+ * + */ + +/** + * **ServerStatus** - Server connection status indicator + * + * Compact status display showing connection state, model name, + * and context size. Used in headers and loading screens. + * + * **Architecture:** + * - Reads state from serverStore (props, loading, error) + * - Displays model name from modelsStore + * + * **Features:** + * - Status dot: green (connected), yellow (connecting), red (error), gray (unknown) + * - Status text label + * - Model name badge with icon + * - Context size badge + * - Optional error action button + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerStatus } from './ServerStatus.svelte'; + +/** + * **ServerErrorSplash** - Full-screen connection error display + * + * Blocking error screen shown when server connection fails. + * Provides retry options and API key input for authentication errors. + * + * **Architecture:** + * - Detects access denied errors for API key flow + * - Validates API key against server before saving + * - Integrates with settingsStore for API key persistence + * + * **Features:** + * - Error message display with icon + * - Retry connection button with loading state + * - API key input for authentication errors + * - API key validation with success/error feedback + * - Troubleshooting section with server start commands + * - Animated transitions for UI elements + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerErrorSplash } from './ServerErrorSplash.svelte'; + +/** + * **ServerLoadingSplash** - Full-screen loading display + * + * Shown during initial server connection. Displays loading animation + * with ServerStatus component for real-time connection state. + * + * **Features:** + * - Animated server icon + * - Customizable loading message + * - Embedded ServerStatus for live updates + * + * @example + * ```svelte + * + * ``` + */ +export { default as ServerLoadingSplash } from './ServerLoadingSplash.svelte'; diff --git a/tools/server/webui/src/lib/components/ui/badge/badge.svelte b/tools/server/webui/src/lib/components/ui/badge/badge.svelte index 4d15145493..c3e6ac0720 100644 --- a/tools/server/webui/src/lib/components/ui/badge/badge.svelte +++ b/tools/server/webui/src/lib/components/ui/badge/badge.svelte @@ -42,7 +42,7 @@ bind:this={ref} data-slot="badge" {href} - class={cn(badgeVariants({ variant }), className)} + class={cn(badgeVariants({ variant }), className, 'backdrop-blur-sm')} {...restProps} > {@render children?.()} diff --git a/tools/server/webui/src/lib/components/ui/button/button.svelte b/tools/server/webui/src/lib/components/ui/button/button.svelte index d12c8de147..d29358c8e0 100644 --- a/tools/server/webui/src/lib/components/ui/button/button.svelte +++ b/tools/server/webui/src/lib/components/ui/button/button.svelte @@ -12,8 +12,9 @@ 'bg-destructive shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60 text-white', outline: 'bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50 border', - secondary: 'bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80', - ghost: 'hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50', + secondary: + 'dark:bg-secondary dark:text-secondary-foreground bg-background shadow-sm text-foreground hover:bg-muted-foreground/20', + ghost: 'hover:text-accent-foreground hover:bg-muted-foreground/10', link: 'text-primary underline-offset-4 hover:underline' 
}, size: { diff --git a/tools/server/webui/src/lib/components/ui/card/card.svelte b/tools/server/webui/src/lib/components/ui/card/card.svelte index c40d14309f..b9dcd2de6f 100644 --- a/tools/server/webui/src/lib/components/ui/card/card.svelte +++ b/tools/server/webui/src/lib/components/ui/card/card.svelte @@ -1,6 +1,7 @@ - +{#snippet tooltipContent()} {@render children?.()} @@ -44,4 +50,12 @@ {/snippet} - +{/snippet} + +{#if noPortal} + {@render tooltipContent()} +{:else} + + {@render tooltipContent()} + +{/if} diff --git a/tools/server/webui/src/lib/constants/binary-detection.ts b/tools/server/webui/src/lib/constants/binary-detection.ts index a4440fde5d..eac919ad96 100644 --- a/tools/server/webui/src/lib/constants/binary-detection.ts +++ b/tools/server/webui/src/lib/constants/binary-detection.ts @@ -1,9 +1,6 @@ export interface BinaryDetectionOptions { - /** Number of characters to check from the beginning of the file */ prefixLength: number; - /** Maximum ratio of suspicious characters allowed (0.0 to 1.0) */ suspiciousCharThresholdRatio: number; - /** Maximum absolute number of null bytes allowed */ maxAbsoluteNullBytes: number; } diff --git a/tools/server/webui/src/lib/constants/chat-form.ts b/tools/server/webui/src/lib/constants/chat-form.ts new file mode 100644 index 0000000000..c5e3dc3d1b --- /dev/null +++ b/tools/server/webui/src/lib/constants/chat-form.ts @@ -0,0 +1,3 @@ +export const INITIAL_FILE_SIZE = 0; +export const PROMPT_CONTENT_SEPARATOR = '\n\n'; +export const CLIPBOARD_CONTENT_QUOTE_PREFIX = '"'; diff --git a/tools/server/webui/src/lib/constants/code-blocks.ts b/tools/server/webui/src/lib/constants/code-blocks.ts new file mode 100644 index 0000000000..0f7265104d --- /dev/null +++ b/tools/server/webui/src/lib/constants/code-blocks.ts @@ -0,0 +1,8 @@ +export const CODE_BLOCK_SCROLL_CONTAINER_CLASS = 'code-block-scroll-container'; +export const CODE_BLOCK_WRAPPER_CLASS = 'code-block-wrapper'; +export const CODE_BLOCK_HEADER_CLASS = 'code-block-header'; +export const CODE_BLOCK_ACTIONS_CLASS = 'code-block-actions'; +export const CODE_LANGUAGE_CLASS = 'code-language'; +export const COPY_CODE_BTN_CLASS = 'copy-code-btn'; +export const PREVIEW_CODE_BTN_CLASS = 'preview-code-btn'; +export const RELATIVE_CLASS = 'relative'; diff --git a/tools/server/webui/src/lib/constants/code.ts b/tools/server/webui/src/lib/constants/code.ts new file mode 100644 index 0000000000..12bcd0db77 --- /dev/null +++ b/tools/server/webui/src/lib/constants/code.ts @@ -0,0 +1,7 @@ +export const NEWLINE = '\n'; +export const DEFAULT_LANGUAGE = 'text'; +export const LANG_PATTERN = /^(\w*)\n?/; +export const AMPERSAND_REGEX = /&/g; +export const LT_REGEX = //g; +export const FENCE_PATTERN = /^```|\n```/g; diff --git a/tools/server/webui/src/lib/constants/css-classes.ts b/tools/server/webui/src/lib/constants/css-classes.ts new file mode 100644 index 0000000000..46076e55f6 --- /dev/null +++ b/tools/server/webui/src/lib/constants/css-classes.ts @@ -0,0 +1,10 @@ +export const BOX_BORDER = + 'border border-border/30 focus-within:border-border dark:border-border/20 dark:focus-within:border-border'; + +export const INPUT_CLASSES = ` + bg-muted/60 dark:bg-muted/75 + ${BOX_BORDER} + shadow-sm + outline-none + text-foreground +`; diff --git a/tools/server/webui/src/lib/constants/formatters.ts b/tools/server/webui/src/lib/constants/formatters.ts new file mode 100644 index 0000000000..d6d1b883ff --- /dev/null +++ b/tools/server/webui/src/lib/constants/formatters.ts @@ -0,0 +1,8 @@ +export const MS_PER_SECOND = 1000; 
+export const SECONDS_PER_MINUTE = 60; +export const SECONDS_PER_HOUR = 3600; +export const SHORT_DURATION_THRESHOLD = 1; +export const MEDIUM_DURATION_THRESHOLD = 10; + +/** Default display value when no performance time is available */ +export const DEFAULT_PERFORMANCE_TIME = '0s'; diff --git a/tools/server/webui/src/lib/constants/markdown.ts b/tools/server/webui/src/lib/constants/markdown.ts new file mode 100644 index 0000000000..783d31a22c --- /dev/null +++ b/tools/server/webui/src/lib/constants/markdown.ts @@ -0,0 +1,4 @@ +export const IMAGE_NOT_ERROR_BOUND_SELECTOR = 'img:not([data-error-bound])'; +export const DATA_ERROR_BOUND_ATTR = 'errorBound'; +export const DATA_ERROR_HANDLED_ATTR = 'errorHandled'; +export const BOOL_TRUE_STRING = 'true'; diff --git a/tools/server/webui/src/lib/constants/processing-info.ts b/tools/server/webui/src/lib/constants/processing-info.ts index 726439211b..2c3f7dc534 100644 --- a/tools/server/webui/src/lib/constants/processing-info.ts +++ b/tools/server/webui/src/lib/constants/processing-info.ts @@ -1 +1,8 @@ export const PROCESSING_INFO_TIMEOUT = 2000; + +/** + * Statistics units labels + */ +export const STATS_UNITS = { + TOKENS_PER_SECOND: 't/s' +} as const; diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index cac48a557c..1b959f3b69 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -7,7 +7,8 @@ export const SETTING_CONFIG_DEFAULT: Record = theme: 'system', showThoughtInProgress: false, showToolCalls: false, - disableReasoningFormat: false, + disableReasoningParsing: false, + showRawOutputSwitch: false, keepStatsVisible: false, showMessageStats: true, askForTitleConfirmation: false, @@ -92,8 +93,10 @@ export const SETTING_CONFIG_INFO: Record = { showThoughtInProgress: 'Expand thought process by default when generating messages.', showToolCalls: 'Display tool call labels and payloads from Harmony-compatible delta.tool_calls data below assistant messages.', - disableReasoningFormat: - 'Show raw LLM output without backend parsing and frontend Markdown rendering to inspect streaming across different models.', + disableReasoningParsing: + 'Send reasoning_format=none to prevent server-side extraction of reasoning tokens into separate field', + showRawOutputSwitch: + 'Show toggle button to display messages as plain text instead of Markdown-formatted content', keepStatsVisible: 'Keep processing statistics visible after generation finishes.', showMessageStats: 'Display generation statistics (tokens/second, token count, duration) below each assistant message.', diff --git a/tools/server/webui/src/lib/constants/settings-fields.ts b/tools/server/webui/src/lib/constants/settings-fields.ts new file mode 100644 index 0000000000..79a6e92870 --- /dev/null +++ b/tools/server/webui/src/lib/constants/settings-fields.ts @@ -0,0 +1,33 @@ +/** + * List of all numeric fields in settings configuration. + * These fields will be converted from strings to numbers during save. 
+ */ +export const NUMERIC_FIELDS = [ + 'temperature', + 'top_k', + 'top_p', + 'min_p', + 'max_tokens', + 'pasteLongTextToFileLen', + 'dynatemp_range', + 'dynatemp_exponent', + 'typ_p', + 'xtc_probability', + 'xtc_threshold', + 'repeat_last_n', + 'repeat_penalty', + 'presence_penalty', + 'frequency_penalty', + 'dry_multiplier', + 'dry_base', + 'dry_allowed_length', + 'dry_penalty_last_n', + 'agenticMaxTurns', + 'agenticMaxToolPreviewLines' +] as const; + +/** + * Fields that must be positive integers (>= 1). + * These will be clamped to minimum 1 and rounded during save. + */ +export const POSITIVE_INTEGER_FIELDS = ['agenticMaxTurns', 'agenticMaxToolPreviewLines'] as const; diff --git a/tools/server/webui/src/lib/constants/tooltip-config.ts b/tools/server/webui/src/lib/constants/tooltip-config.ts index 3c30c8c072..ad76ab3522 100644 --- a/tools/server/webui/src/lib/constants/tooltip-config.ts +++ b/tools/server/webui/src/lib/constants/tooltip-config.ts @@ -1 +1 @@ -export const TOOLTIP_DELAY_DURATION = 100; +export const TOOLTIP_DELAY_DURATION = 500; diff --git a/tools/server/webui/src/lib/constants/ui.ts b/tools/server/webui/src/lib/constants/ui.ts new file mode 100644 index 0000000000..a75b30f2f8 --- /dev/null +++ b/tools/server/webui/src/lib/constants/ui.ts @@ -0,0 +1 @@ +export const SYSTEM_MESSAGE_PLACEHOLDER = 'System message'; diff --git a/tools/server/webui/src/lib/contexts/chat-actions.context.ts b/tools/server/webui/src/lib/contexts/chat-actions.context.ts new file mode 100644 index 0000000000..eba0fec027 --- /dev/null +++ b/tools/server/webui/src/lib/contexts/chat-actions.context.ts @@ -0,0 +1,34 @@ +import { getContext, setContext } from 'svelte'; + +export interface ChatActionsContext { + copy: (message: DatabaseMessage) => void; + delete: (message: DatabaseMessage) => void; + navigateToSibling: (siblingId: string) => void; + editWithBranching: ( + message: DatabaseMessage, + newContent: string, + newExtras?: DatabaseMessageExtra[] + ) => void; + editWithReplacement: ( + message: DatabaseMessage, + newContent: string, + shouldBranch: boolean + ) => void; + editUserMessagePreserveResponses: ( + message: DatabaseMessage, + newContent: string, + newExtras?: DatabaseMessageExtra[] + ) => void; + regenerateWithBranching: (message: DatabaseMessage, modelOverride?: string) => void; + continueAssistantMessage: (message: DatabaseMessage) => void; +} + +const CHAT_ACTIONS_KEY = Symbol.for('chat-actions'); + +export function setChatActionsContext(ctx: ChatActionsContext): ChatActionsContext { + return setContext(CHAT_ACTIONS_KEY, ctx); +} + +export function getChatActionsContext(): ChatActionsContext { + return getContext(CHAT_ACTIONS_KEY); +} diff --git a/tools/server/webui/src/lib/contexts/index.ts b/tools/server/webui/src/lib/contexts/index.ts new file mode 100644 index 0000000000..73ff6f96fa --- /dev/null +++ b/tools/server/webui/src/lib/contexts/index.ts @@ -0,0 +1,13 @@ +export { + getMessageEditContext, + setMessageEditContext, + type MessageEditContext, + type MessageEditState, + type MessageEditActions +} from './message-edit.context'; + +export { + getChatActionsContext, + setChatActionsContext, + type ChatActionsContext +} from './chat-actions.context'; diff --git a/tools/server/webui/src/lib/contexts/message-edit.context.ts b/tools/server/webui/src/lib/contexts/message-edit.context.ts new file mode 100644 index 0000000000..7af116daa5 --- /dev/null +++ b/tools/server/webui/src/lib/contexts/message-edit.context.ts @@ -0,0 +1,39 @@ +import { getContext, setContext } from 
'svelte'; + +export interface MessageEditState { + readonly isEditing: boolean; + readonly editedContent: string; + readonly editedExtras: DatabaseMessageExtra[]; + readonly editedUploadedFiles: ChatUploadedFile[]; + readonly originalContent: string; + readonly originalExtras: DatabaseMessageExtra[]; + readonly showSaveOnlyOption: boolean; +} + +export interface MessageEditActions { + setContent: (content: string) => void; + setExtras: (extras: DatabaseMessageExtra[]) => void; + setUploadedFiles: (files: ChatUploadedFile[]) => void; + save: () => void; + saveOnly: () => void; + cancel: () => void; + startEdit: () => void; +} + +export type MessageEditContext = MessageEditState & MessageEditActions; + +const MESSAGE_EDIT_KEY = Symbol.for('chat-message-edit'); + +/** + * Sets the message edit context. Call this in the parent component (ChatMessage.svelte). + */ +export function setMessageEditContext(ctx: MessageEditContext): MessageEditContext { + return setContext(MESSAGE_EDIT_KEY, ctx); +} + +/** + * Gets the message edit context. Call this in child components. + */ +export function getMessageEditContext(): MessageEditContext { + return getContext(MESSAGE_EDIT_KEY); +} diff --git a/tools/server/webui/src/lib/enums/chat.ts b/tools/server/webui/src/lib/enums/chat.ts index 2b9eb7bc2e..0b6f357d9a 100644 --- a/tools/server/webui/src/lib/enums/chat.ts +++ b/tools/server/webui/src/lib/enums/chat.ts @@ -1,4 +1,51 @@ export enum ChatMessageStatsView { GENERATION = 'generation', - READING = 'reading' + READING = 'reading', + TOOLS = 'tools', + SUMMARY = 'summary' +} + +/** + * Reasoning format options for API requests. + */ +export enum ReasoningFormat { + NONE = 'none', + AUTO = 'auto' +} + +/** + * Message roles for chat messages. + */ +export enum MessageRole { + USER = 'user', + ASSISTANT = 'assistant', + SYSTEM = 'system', + TOOL = 'tool' +} + +/** + * Message types for different content kinds. + */ +export enum MessageType { + ROOT = 'root', + TEXT = 'text', + THINK = 'think', + SYSTEM = 'system' +} + +/** + * Content part types for API chat message content. + */ +export enum ContentPartType { + TEXT = 'text', + IMAGE_URL = 'image_url', + INPUT_AUDIO = 'input_audio' +} + +/** + * Error dialog types for displaying server/timeout errors. 
+ */ +export enum ErrorDialogType { + TIMEOUT = 'timeout', + SERVER = 'server' } diff --git a/tools/server/webui/src/lib/enums/keyboard.ts b/tools/server/webui/src/lib/enums/keyboard.ts new file mode 100644 index 0000000000..b8f6d5f7a2 --- /dev/null +++ b/tools/server/webui/src/lib/enums/keyboard.ts @@ -0,0 +1,15 @@ +/** + * Keyboard key names for event handling + */ +export enum KeyboardKey { + ENTER = 'Enter', + ESCAPE = 'Escape', + ARROW_UP = 'ArrowUp', + ARROW_DOWN = 'ArrowDown', + TAB = 'Tab', + D_LOWER = 'd', + D_UPPER = 'D', + E_UPPER = 'E', + K_LOWER = 'k', + O_UPPER = 'O' +} diff --git a/tools/server/webui/src/lib/enums/settings.ts b/tools/server/webui/src/lib/enums/settings.ts new file mode 100644 index 0000000000..f17f219762 --- /dev/null +++ b/tools/server/webui/src/lib/enums/settings.ts @@ -0,0 +1,26 @@ +/** + * Parameter source - indicates whether a parameter uses default or custom value + */ +export enum ParameterSource { + DEFAULT = 'default', + CUSTOM = 'custom' +} + +/** + * Syncable parameter type - data types for parameters that can be synced with server + */ +export enum SyncableParameterType { + NUMBER = 'number', + STRING = 'string', + BOOLEAN = 'boolean' +} + +/** + * Settings field type - defines the input type for settings fields + */ +export enum SettingsFieldType { + INPUT = 'input', + TEXTAREA = 'textarea', + CHECKBOX = 'checkbox', + SELECT = 'select' +} diff --git a/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts b/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts new file mode 100644 index 0000000000..bbaa5d1362 --- /dev/null +++ b/tools/server/webui/src/lib/hooks/use-auto-scroll.svelte.ts @@ -0,0 +1,165 @@ +import { AUTO_SCROLL_AT_BOTTOM_THRESHOLD, AUTO_SCROLL_INTERVAL } from '$lib/constants/auto-scroll'; + +export interface AutoScrollOptions { + /** Whether auto-scroll is disabled globally (e.g., from settings) */ + disabled?: boolean; +} + +/** + * Creates an auto-scroll controller for a scrollable container. + * + * Features: + * - Auto-scrolls to bottom during streaming/loading + * - Stops auto-scroll when user manually scrolls up + * - Resumes auto-scroll when user scrolls back to bottom + */ +export class AutoScrollController { + private _autoScrollEnabled = $state(true); + private _userScrolledUp = $state(false); + private _lastScrollTop = $state(0); + private _scrollInterval: ReturnType | undefined; + private _scrollTimeout: ReturnType | undefined; + private _container: HTMLElement | undefined; + private _disabled: boolean; + + constructor(options: AutoScrollOptions = {}) { + this._disabled = options.disabled ?? false; + } + + get autoScrollEnabled(): boolean { + return this._autoScrollEnabled; + } + + get userScrolledUp(): boolean { + return this._userScrolledUp; + } + + /** + * Binds the controller to a scrollable container element. + */ + setContainer(container: HTMLElement | undefined): void { + this._container = container; + } + + /** + * Updates the disabled state. + */ + setDisabled(disabled: boolean): void { + this._disabled = disabled; + if (disabled) { + this._autoScrollEnabled = false; + this.stopInterval(); + } + } + + /** + * Handles scroll events to detect user scroll direction and toggle auto-scroll. 
+ */ + handleScroll(): void { + if (this._disabled || !this._container) return; + + const { scrollTop, scrollHeight, clientHeight } = this._container; + const distanceFromBottom = scrollHeight - scrollTop - clientHeight; + const isAtBottom = distanceFromBottom < AUTO_SCROLL_AT_BOTTOM_THRESHOLD; + + if (scrollTop < this._lastScrollTop && !isAtBottom) { + this._userScrolledUp = true; + this._autoScrollEnabled = false; + } else if (isAtBottom && this._userScrolledUp) { + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + + if (this._scrollTimeout) { + clearTimeout(this._scrollTimeout); + } + + this._scrollTimeout = setTimeout(() => { + if (isAtBottom) { + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + }, AUTO_SCROLL_INTERVAL); + + this._lastScrollTop = scrollTop; + } + + /** + * Scrolls the container to the bottom. + */ + scrollToBottom(behavior: ScrollBehavior = 'smooth'): void { + if (this._disabled || !this._container) return; + + this._container.scrollTo({ + top: this._container.scrollHeight, + behavior + }); + } + + /** + * Enables auto-scroll (e.g., when user sends a message). + */ + enable(): void { + if (this._disabled) return; + this._userScrolledUp = false; + this._autoScrollEnabled = true; + } + + /** + * Starts the auto-scroll interval for continuous scrolling during streaming. + */ + startInterval(): void { + if (this._disabled || this._scrollInterval) return; + + this._scrollInterval = setInterval(() => { + this.scrollToBottom(); + }, AUTO_SCROLL_INTERVAL); + } + + /** + * Stops the auto-scroll interval. + */ + stopInterval(): void { + if (this._scrollInterval) { + clearInterval(this._scrollInterval); + this._scrollInterval = undefined; + } + } + + /** + * Updates the auto-scroll interval based on streaming state. + * Call this in a $effect to automatically manage the interval. + */ + updateInterval(isStreaming: boolean): void { + if (this._disabled) { + this.stopInterval(); + return; + } + + if (isStreaming && this._autoScrollEnabled) { + if (!this._scrollInterval) { + this.startInterval(); + } + } else { + this.stopInterval(); + } + } + + /** + * Cleans up resources. Call this in onDestroy or when the component unmounts. + */ + destroy(): void { + this.stopInterval(); + if (this._scrollTimeout) { + clearTimeout(this._scrollTimeout); + this._scrollTimeout = undefined; + } + } +} + +/** + * Creates a new AutoScrollController instance. 
+ */ +export function createAutoScrollController(options: AutoScrollOptions = {}): AutoScrollController { + return new AutoScrollController(options); +} diff --git a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts index c06cf28864..068440cdc0 100644 --- a/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts +++ b/tools/server/webui/src/lib/hooks/use-processing-state.svelte.ts @@ -1,7 +1,9 @@ import { activeProcessingState } from '$lib/stores/chat.svelte'; import { config } from '$lib/stores/settings.svelte'; +import { STATS_UNITS } from '$lib/constants/processing-info'; +import type { ApiProcessingState } from '$lib/types'; -export interface LiveProcessingStats { +interface LiveProcessingStats { tokensProcessed: number; totalTokens: number; timeMs: number; @@ -9,7 +11,7 @@ export interface LiveProcessingStats { etaSecs?: number; } -export interface LiveGenerationStats { +interface LiveGenerationStats { tokensGenerated: number; timeMs: number; tokensPerSecond: number; @@ -18,6 +20,7 @@ export interface LiveGenerationStats { export interface UseProcessingStateReturn { readonly processingState: ApiProcessingState | null; getProcessingDetails(): string[]; + getTechnicalDetails(): string[]; getProcessingMessage(): string; getPromptProgressText(): string | null; getLiveProcessingStats(): LiveProcessingStats | null; @@ -138,8 +141,31 @@ export function useProcessingState(): UseProcessingStateReturn { const details: string[] = []; + // Show prompt processing progress with ETA during preparation phase + if (stateToUse.promptProgress) { + const { processed, total, time_ms, cache } = stateToUse.promptProgress; + const actualProcessed = processed - cache; + const actualTotal = total - cache; + + if (actualProcessed < actualTotal && actualProcessed > 0) { + const percent = Math.round((actualProcessed / actualTotal) * 100); + const eta = getETASecs(actualProcessed, actualTotal, time_ms); + + if (eta !== undefined) { + const etaSecs = Math.ceil(eta); + details.push(`Processing ${percent}% (ETA: ${etaSecs}s)`); + } else { + details.push(`Processing ${percent}%`); + } + } + } + // Always show context info when we have valid data - if (stateToUse.contextUsed >= 0 && stateToUse.contextTotal > 0) { + if ( + typeof stateToUse.contextTotal === 'number' && + stateToUse.contextUsed >= 0 && + stateToUse.contextTotal > 0 + ) { const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100); details.push( @@ -163,7 +189,57 @@ export function useProcessingState(): UseProcessingStateReturn { } if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { - details.push(`${stateToUse.tokensPerSecond.toFixed(1)} tokens/sec`); + details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`); + } + + if (stateToUse.speculative) { + details.push('Speculative decoding enabled'); + } + + return details; + } + + /** + * Returns technical details without the progress message (for bottom bar) + */ + function getTechnicalDetails(): string[] { + const stateToUse = processingState || lastKnownState; + if (!stateToUse) { + return []; + } + + const details: string[] = []; + + // Always show context info when we have valid data + if ( + typeof stateToUse.contextTotal === 'number' && + stateToUse.contextUsed >= 0 && + stateToUse.contextTotal > 0 + ) { + const contextPercent = Math.round((stateToUse.contextUsed / stateToUse.contextTotal) * 100); + + details.push( + `Context: 
${stateToUse.contextUsed}/${stateToUse.contextTotal} (${contextPercent}%)` + ); + } + + if (stateToUse.outputTokensUsed > 0) { + // Handle infinite max_tokens (-1) case + if (stateToUse.outputTokensMax <= 0) { + details.push(`Output: ${stateToUse.outputTokensUsed}/∞`); + } else { + const outputPercent = Math.round( + (stateToUse.outputTokensUsed / stateToUse.outputTokensMax) * 100 + ); + + details.push( + `Output: ${stateToUse.outputTokensUsed}/${stateToUse.outputTokensMax} (${outputPercent}%)` + ); + } + } + + if (stateToUse.tokensPerSecond && stateToUse.tokensPerSecond > 0) { + details.push(`${stateToUse.tokensPerSecond.toFixed(1)} ${STATS_UNITS.TOKENS_PER_SECOND}`); } if (stateToUse.speculative) { @@ -251,6 +327,7 @@ export function useProcessingState(): UseProcessingStateReturn { return processingState; }, getProcessingDetails, + getTechnicalDetails, getProcessingMessage, getPromptProgressText, getLiveProcessingStats, diff --git a/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts b/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts index 6f0e03e211..168de97403 100644 --- a/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts +++ b/tools/server/webui/src/lib/markdown/enhance-code-blocks.ts @@ -13,6 +13,16 @@ import type { Plugin } from 'unified'; import type { Root, Element, ElementContent } from 'hast'; import { visit } from 'unist-util-visit'; +import { + CODE_BLOCK_SCROLL_CONTAINER_CLASS, + CODE_BLOCK_WRAPPER_CLASS, + CODE_BLOCK_HEADER_CLASS, + CODE_BLOCK_ACTIONS_CLASS, + CODE_LANGUAGE_CLASS, + COPY_CODE_BTN_CLASS, + PREVIEW_CODE_BTN_CLASS, + RELATIVE_CLASS +} from '$lib/constants/code-blocks'; declare global { interface Window { @@ -42,7 +52,7 @@ function createCopyButton(codeId: string): Element { type: 'element', tagName: 'button', properties: { - className: ['copy-code-btn'], + className: [COPY_CODE_BTN_CLASS], 'data-code-id': codeId, title: 'Copy code', type: 'button' @@ -56,7 +66,7 @@ function createPreviewButton(codeId: string): Element { type: 'element', tagName: 'button', properties: { - className: ['preview-code-btn'], + className: [PREVIEW_CODE_BTN_CLASS], 'data-code-id': codeId, title: 'Preview code', type: 'button' @@ -75,30 +85,39 @@ function createHeader(language: string, codeId: string): Element { return { type: 'element', tagName: 'div', - properties: { className: ['code-block-header'] }, + properties: { className: [CODE_BLOCK_HEADER_CLASS] }, children: [ { type: 'element', tagName: 'span', - properties: { className: ['code-language'] }, + properties: { className: [CODE_LANGUAGE_CLASS] }, children: [{ type: 'text', value: language }] }, { type: 'element', tagName: 'div', - properties: { className: ['code-block-actions'] }, + properties: { className: [CODE_BLOCK_ACTIONS_CLASS] }, children: actions } ] }; } +function createScrollContainer(preElement: Element): Element { + return { + type: 'element', + tagName: 'div', + properties: { className: [CODE_BLOCK_SCROLL_CONTAINER_CLASS] }, + children: [preElement] + }; +} + function createWrapper(header: Element, preElement: Element): Element { return { type: 'element', tagName: 'div', - properties: { className: ['code-block-wrapper'] }, - children: [header, preElement] + properties: { className: [CODE_BLOCK_WRAPPER_CLASS, RELATIVE_CLASS] }, + children: [header, createScrollContainer(preElement)] }; } diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 02fc6381c0..55af0ce816 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ 
b/tools/server/webui/src/lib/services/chat.ts
@@ -90,7 +90,7 @@ export class ChatService {
 			custom,
 			timings_per_token,
 			// Config options
-			disableReasoningFormat
+			disableReasoningParsing
 		} = options;
 
 		const normalizedMessages: ApiChatMessageData[] = messages
@@ -127,7 +127,7 @@ export class ChatService {
 			requestBody.model = options.model;
 		}
 
-		requestBody.reasoning_format = disableReasoningFormat ? 'none' : 'auto';
+		requestBody.reasoning_format = disableReasoningParsing ? 'none' : 'auto';
 
 		if (temperature !== undefined) requestBody.temperature = temperature;
 		if (max_tokens !== undefined) {
diff --git a/tools/server/webui/src/lib/services/database.service.ts b/tools/server/webui/src/lib/services/database.service.ts
new file mode 100644
index 0000000000..0d5a9c1b99
--- /dev/null
+++ b/tools/server/webui/src/lib/services/database.service.ts
@@ -0,0 +1,368 @@
+import Dexie, { type EntityTable } from 'dexie';
+import { findDescendantMessages } from '$lib/utils';
+
+class LlamacppDatabase extends Dexie {
+	conversations!: EntityTable<DatabaseConversation, 'id'>;
+	messages!: EntityTable<DatabaseMessage, 'id'>;
+
+	constructor() {
+		super('LlamacppWebui');
+
+		this.version(1).stores({
+			conversations: 'id, lastModified, currNode, name',
+			messages: 'id, convId, type, role, timestamp, parent, children'
+		});
+	}
+}
+
+const db = new LlamacppDatabase();
+import { v4 as uuid } from 'uuid';
+import { MessageRole } from '$lib/enums/chat';
+
+export class DatabaseService {
+	/**
+	 *
+	 *
+	 * Conversations
+	 *
+	 *
+	 */
+
+	/**
+	 * Creates a new conversation.
+	 *
+	 * @param name - Name of the conversation
+	 * @returns The created conversation
+	 */
+	static async createConversation(name: string): Promise<DatabaseConversation> {
+		const conversation: DatabaseConversation = {
+			id: uuid(),
+			name,
+			lastModified: Date.now(),
+			currNode: ''
+		};
+
+		await db.conversations.add(conversation);
+		return conversation;
+	}
+
+	/**
+	 *
+	 *
+	 * Messages
+	 *
+	 *
+	 */
+
+	/**
+	 * Creates a new message branch by adding a message and updating parent/child relationships.
+	 * Also updates the conversation's currNode to point to the new message.
+	 *
+	 * @param message - Message to add (without id)
+	 * @param parentId - Parent message ID to attach to
+	 * @returns The created message
+	 */
+	static async createMessageBranch(
+		message: Omit<DatabaseMessage, 'id'>,
+		parentId: string | null
+	): Promise<DatabaseMessage> {
+		return await db.transaction('rw', [db.conversations, db.messages], async () => {
+			// Handle null parent (root message case)
+			if (parentId !== null) {
+				const parentMessage = await db.messages.get(parentId);
+				if (!parentMessage) {
+					throw new Error(`Parent message ${parentId} not found`);
+				}
+			}
+
+			const newMessage: DatabaseMessage = {
+				...message,
+				id: uuid(),
+				parent: parentId,
+				toolCalls: message.toolCalls ?? '',
+				children: []
+			};
+
+			await db.messages.add(newMessage);
+
+			// Update parent's children array if parent exists
+			if (parentId !== null) {
+				const parentMessage = await db.messages.get(parentId);
+				if (parentMessage) {
+					await db.messages.update(parentId, {
+						children: [...parentMessage.children, newMessage.id]
+					});
+				}
+			}
+
+			await this.updateConversation(message.convId, {
+				currNode: newMessage.id
+			});
+
+			return newMessage;
+		});
+	}
+
+	/**
+	 * Creates a root message for a new conversation.
+	 * Root messages are not displayed but serve as the tree root for branching.
+	 *
+	 * @param convId - Conversation ID
+	 * @returns The ID of the created root message
+	 */
+	static async createRootMessage(convId: string): Promise<string> {
+		const rootMessage: DatabaseMessage = {
+			id: uuid(),
+			convId,
+			type: 'root',
+			timestamp: Date.now(),
+			role: MessageRole.SYSTEM,
+			content: '',
+			parent: null,
+			toolCalls: '',
+			children: []
+		};
+
+		await db.messages.add(rootMessage);
+		return rootMessage.id;
+	}
+
+	/**
+	 * Creates a system prompt message for a conversation.
+	 *
+	 * @param convId - Conversation ID
+	 * @param systemPrompt - The system prompt content (must be non-empty)
+	 * @param parentId - Parent message ID (typically the root message)
+	 * @returns The created system message
+	 * @throws Error if systemPrompt is empty
+	 */
+	static async createSystemMessage(
+		convId: string,
+		systemPrompt: string,
+		parentId: string
+	): Promise<DatabaseMessage> {
+		const trimmedPrompt = systemPrompt.trim();
+		if (!trimmedPrompt) {
+			throw new Error('Cannot create system message with empty content');
+		}
+
+		const systemMessage: DatabaseMessage = {
+			id: uuid(),
+			convId,
+			type: MessageRole.SYSTEM,
+			timestamp: Date.now(),
+			role: MessageRole.SYSTEM,
+			content: trimmedPrompt,
+			parent: parentId,
+			children: []
+		};
+
+		await db.messages.add(systemMessage);
+
+		const parentMessage = await db.messages.get(parentId);
+		if (parentMessage) {
+			await db.messages.update(parentId, {
+				children: [...parentMessage.children, systemMessage.id]
+			});
+		}
+
+		return systemMessage;
+	}
+
+	/**
+	 * Deletes a conversation and all its messages.
+	 *
+	 * @param id - Conversation ID
+	 */
+	static async deleteConversation(id: string): Promise<void> {
+		await db.transaction('rw', [db.conversations, db.messages], async () => {
+			await db.conversations.delete(id);
+			await db.messages.where('convId').equals(id).delete();
+		});
+	}
+
+	/**
+	 * Deletes a message and removes it from its parent's children array.
+	 *
+	 * @param messageId - ID of the message to delete
+	 */
+	static async deleteMessage(messageId: string): Promise<void> {
+		await db.transaction('rw', db.messages, async () => {
+			const message = await db.messages.get(messageId);
+			if (!message) return;
+
+			// Remove this message from its parent's children array
+			if (message.parent) {
+				const parent = await db.messages.get(message.parent);
+				if (parent) {
+					parent.children = parent.children.filter((childId: string) => childId !== messageId);
+					await db.messages.put(parent);
+				}
+			}
+
+			// Delete the message
+			await db.messages.delete(messageId);
+		});
+	}
+
+	/**
+	 * Deletes a message and all its descendant messages (cascading deletion).
+	 * This removes the entire branch starting from the specified message.
+	 *
+	 * @param conversationId - ID of the conversation containing the message
+	 * @param messageId - ID of the root message to delete (along with all descendants)
+	 * @returns Array of all deleted message IDs
+	 */
+	static async deleteMessageCascading(
+		conversationId: string,
+		messageId: string
+	): Promise<string[]> {
+		return await db.transaction('rw', db.messages, async () => {
+			// Get all messages in the conversation to find descendants
+			const allMessages = await db.messages.where('convId').equals(conversationId).toArray();
+
+			// Find all descendant messages
+			const descendants = findDescendantMessages(allMessages, messageId);
+			const allToDelete = [messageId, ...descendants];
+
+			// Get the message to delete for parent cleanup
+			const message = await db.messages.get(messageId);
+			if (message && message.parent) {
+				const parent = await db.messages.get(message.parent);
+				if (parent) {
+					parent.children = parent.children.filter((childId: string) => childId !== messageId);
+					await db.messages.put(parent);
+				}
+			}
+
+			// Delete all messages in the branch
+			await db.messages.bulkDelete(allToDelete);
+
+			return allToDelete;
+		});
+	}
+
+	/**
+	 * Gets all conversations, sorted by last modified time (newest first).
+	 *
+	 * @returns Array of conversations
+	 */
+	static async getAllConversations(): Promise<DatabaseConversation[]> {
+		return await db.conversations.orderBy('lastModified').reverse().toArray();
+	}
+
+	/**
+	 * Gets a conversation by ID.
+	 *
+	 * @param id - Conversation ID
+	 * @returns The conversation if found, otherwise undefined
+	 */
+	static async getConversation(id: string): Promise<DatabaseConversation | undefined> {
+		return await db.conversations.get(id);
+	}
+
+	/**
+	 * Gets all messages in a conversation, sorted by timestamp (oldest first).
+	 *
+	 * @param convId - Conversation ID
+	 * @returns Array of messages in the conversation
+	 */
+	static async getConversationMessages(convId: string): Promise<DatabaseMessage[]> {
+		return await db.messages.where('convId').equals(convId).sortBy('timestamp');
+	}
+
+	/**
+	 * Updates a conversation.
+	 *
+	 * @param id - Conversation ID
+	 * @param updates - Partial updates to apply
+	 * @returns Promise that resolves when the conversation is updated
+	 */
+	static async updateConversation(
+		id: string,
+		updates: Partial<Omit<DatabaseConversation, 'id'>>
+	): Promise<void> {
+		await db.conversations.update(id, {
+			...updates,
+			lastModified: Date.now()
+		});
+	}
+
+	/**
+	 *
+	 *
+	 * Navigation
+	 *
+	 *
+	 */
+
+	/**
+	 * Updates the conversation's current node (active branch).
+	 * This determines which conversation path is currently being viewed.
+	 *
+	 * @param convId - Conversation ID
+	 * @param nodeId - Message ID to set as current node
+	 */
+	static async updateCurrentNode(convId: string, nodeId: string): Promise<void> {
+		await this.updateConversation(convId, {
+			currNode: nodeId
+		});
+	}
+
+	/**
+	 * Updates a message.
+	 *
+	 * @param id - Message ID
+	 * @param updates - Partial updates to apply
+	 * @returns Promise that resolves when the message is updated
+	 */
+	static async updateMessage(
+		id: string,
+		updates: Partial<Omit<DatabaseMessage, 'id'>>
+	): Promise<void> {
+		await db.messages.update(id, updates);
+	}
+
+	/**
+	 *
+	 *
+	 * Import
+	 *
+	 *
+	 */
+
+	/**
+	 * Imports multiple conversations and their messages.
+	 * Skips conversations that already exist.
+ * + * @param data - Array of { conv, messages } objects + */ + static async importConversations( + data: { conv: DatabaseConversation; messages: DatabaseMessage[] }[] + ): Promise<{ imported: number; skipped: number }> { + let importedCount = 0; + let skippedCount = 0; + + return await db.transaction('rw', [db.conversations, db.messages], async () => { + for (const item of data) { + const { conv, messages } = item; + + const existing = await db.conversations.get(conv.id); + if (existing) { + console.warn(`Conversation "${conv.name}" already exists, skipping...`); + skippedCount++; + continue; + } + + await db.conversations.add(conv); + for (const msg of messages) { + await db.messages.put(msg); + } + + importedCount++; + } + + return { imported: importedCount, skipped: skippedCount }; + }); + } +} diff --git a/tools/server/webui/src/lib/services/models.service.ts b/tools/server/webui/src/lib/services/models.service.ts new file mode 100644 index 0000000000..7357c3f400 --- /dev/null +++ b/tools/server/webui/src/lib/services/models.service.ts @@ -0,0 +1,99 @@ +import { ServerModelStatus } from '$lib/enums'; +import { apiFetch, apiPost } from '$lib/utils/api-fetch'; + +export class ModelsService { + /** + * + * + * Listing + * + * + */ + + /** + * Fetch list of models from OpenAI-compatible endpoint. + * Works in both MODEL and ROUTER modes. + * + * @returns List of available models with basic metadata + */ + static async list(): Promise { + return apiFetch('/v1/models'); + } + + /** + * Fetch list of all models with detailed metadata (ROUTER mode). + * Returns models with load status, paths, and other metadata + * beyond what the OpenAI-compatible endpoint provides. + * + * @returns List of models with detailed status and configuration info + */ + static async listRouter(): Promise { + return apiFetch('/v1/models'); + } + + /** + * + * + * Load/Unload + * + * + */ + + /** + * Load a model (ROUTER mode only). + * Sends POST request to `/models/load`. Note: the endpoint returns success + * before loading completes — use polling to await actual load status. + * + * @param modelId - Model identifier to load + * @param extraArgs - Optional additional arguments to pass to the model instance + * @returns Load response from the server + */ + static async load(modelId: string, extraArgs?: string[]): Promise { + const payload: { model: string; extra_args?: string[] } = { model: modelId }; + if (extraArgs && extraArgs.length > 0) { + payload.extra_args = extraArgs; + } + + return apiPost('/models/load', payload); + } + + /** + * Unload a model (ROUTER mode only). + * Sends POST request to `/models/unload`. Note: the endpoint returns success + * before unloading completes — use polling to await actual unload status. + * + * @param modelId - Model identifier to unload + * @returns Unload response from the server + */ + static async unload(modelId: string): Promise { + return apiPost('/models/unload', { model: modelId }); + } + + /** + * + * + * Status + * + * + */ + + /** + * Check if a model is loaded based on its metadata. + * + * @param model - Model data entry from the API response + * @returns True if the model status is LOADED + */ + static isModelLoaded(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADED; + } + + /** + * Check if a model is currently loading. 
+ * + * @param model - Model data entry from the API response + * @returns True if the model status is LOADING + */ + static isModelLoading(model: ApiModelDataEntry): boolean { + return model.status.value === ServerModelStatus.LOADING; + } +} diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts new file mode 100644 index 0000000000..46cce5e7cb --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.service.spec.ts @@ -0,0 +1,148 @@ +import { describe, it, expect } from 'vitest'; +import { ParameterSyncService } from './parameter-sync.service'; + +describe('ParameterSyncService', () => { + describe('roundFloatingPoint', () => { + it('should fix JavaScript floating-point precision issues', () => { + // Test the specific values from the screenshot + const mockServerParams = { + top_p: 0.949999988079071, + min_p: 0.009999999776482582, + temperature: 0.800000011920929, + top_k: 40, + samplers: ['top_k', 'typ_p', 'top_p', 'min_p', 'temperature'] + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Add other required fields to match the API type + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + max_tokens: -1, + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + // Check that the problematic floating-point values are rounded correctly + expect(result.top_p).toBe(0.95); + expect(result.min_p).toBe(0.01); + expect(result.temperature).toBe(0.8); + expect(result.top_k).toBe(40); // Integer should remain unchanged + expect(result.samplers).toBe('top_k;typ_p;top_p;min_p;temperature'); + }); + + it('should preserve non-numeric values', () => { + const mockServerParams = { + samplers: ['top_k', 'temperature'], + max_tokens: -1, + temperature: 0.7 + }; + + const result = ParameterSyncService.extractServerDefaults({ + ...mockServerParams, + // Minimal required fields + n_predict: 512, + seed: -1, + dynatemp_range: 0.0, + dynatemp_exponent: 1.0, + top_k: 40, + top_p: 0.95, + min_p: 0.05, + xtc_probability: 0.0, + xtc_threshold: 0.1, + typ_p: 1.0, + repeat_last_n: 64, + repeat_penalty: 1.0, + presence_penalty: 0.0, + frequency_penalty: 0.0, + dry_multiplier: 0.0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + mirostat: 0, + mirostat_tau: 5.0, + mirostat_eta: 0.1, + stop: [], + n_keep: 0, + n_discard: 0, + ignore_eos: false, + stream: true, + logit_bias: [], + n_probs: 0, + min_keep: 0, + grammar: '', + grammar_lazy: false, + grammar_triggers: [], + preserved_tokens: [], + chat_format: '', + reasoning_format: '', + reasoning_in_content: false, + thinking_forced_open: false, + 
'speculative.n_max': 0, + 'speculative.n_min': 0, + 'speculative.p_min': 0.0, + timings_per_token: false, + post_sampling_probs: false, + lora: [], + top_n_sigma: 0.0, + dry_sequence_breakers: [] + } as ApiLlamaCppServerProps['default_generation_settings']['params']); + + expect(result.samplers).toBe('top_k;temperature'); + expect(result.max_tokens).toBe(-1); + expect(result.temperature).toBe(0.7); + }); + + it('should merge webui settings from props when provided', () => { + const result = ParameterSyncService.extractServerDefaults(null, { + pasteLongTextToFileLen: 0, + pdfAsImage: true, + renderUserContentAsMarkdown: false, + theme: 'dark' + }); + + expect(result.pasteLongTextToFileLen).toBe(0); + expect(result.pdfAsImage).toBe(true); + expect(result.renderUserContentAsMarkdown).toBe(false); + expect(result.theme).toBeUndefined(); + }); + }); +}); diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.ts b/tools/server/webui/src/lib/services/parameter-sync.service.ts new file mode 100644 index 0000000000..6cb53d12d1 --- /dev/null +++ b/tools/server/webui/src/lib/services/parameter-sync.service.ts @@ -0,0 +1,400 @@ +import { normalizeFloatingPoint } from '$lib/utils'; +import { SyncableParameterType, ParameterSource } from '$lib/enums/settings'; + +type ParameterValue = string | number | boolean; +type ParameterRecord = Record; + +interface ParameterInfo { + value: string | number | boolean; + source: ParameterSource; + serverDefault?: string | number | boolean; + userOverride?: string | number | boolean; +} + +interface SyncableParameter { + key: string; + serverKey: string; + type: SyncableParameterType; + canSync: boolean; +} + +/** + * Mapping of webui setting keys to server parameter keys. + * Only parameters listed here can be synced from the server `/props` endpoint. + * Each entry defines the webui key, corresponding server key, value type, + * and whether sync is enabled. 
+ */ +export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ + { + key: 'temperature', + serverKey: 'temperature', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'top_k', serverKey: 'top_k', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'top_p', serverKey: 'top_p', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'min_p', serverKey: 'min_p', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'dynatemp_range', + serverKey: 'dynatemp_range', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dynatemp_exponent', + serverKey: 'dynatemp_exponent', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'xtc_probability', + serverKey: 'xtc_probability', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'xtc_threshold', + serverKey: 'xtc_threshold', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'typ_p', serverKey: 'typ_p', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'repeat_last_n', + serverKey: 'repeat_last_n', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'repeat_penalty', + serverKey: 'repeat_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'presence_penalty', + serverKey: 'presence_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'frequency_penalty', + serverKey: 'frequency_penalty', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dry_multiplier', + serverKey: 'dry_multiplier', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'dry_base', serverKey: 'dry_base', type: SyncableParameterType.NUMBER, canSync: true }, + { + key: 'dry_allowed_length', + serverKey: 'dry_allowed_length', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'dry_penalty_last_n', + serverKey: 'dry_penalty_last_n', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { key: 'max_tokens', serverKey: 'max_tokens', type: SyncableParameterType.NUMBER, canSync: true }, + { key: 'samplers', serverKey: 'samplers', type: SyncableParameterType.STRING, canSync: true }, + { + key: 'pasteLongTextToFileLen', + serverKey: 'pasteLongTextToFileLen', + type: SyncableParameterType.NUMBER, + canSync: true + }, + { + key: 'pdfAsImage', + serverKey: 'pdfAsImage', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'showThoughtInProgress', + serverKey: 'showThoughtInProgress', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'keepStatsVisible', + serverKey: 'keepStatsVisible', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'showMessageStats', + serverKey: 'showMessageStats', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'askForTitleConfirmation', + serverKey: 'askForTitleConfirmation', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'disableAutoScroll', + serverKey: 'disableAutoScroll', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'renderUserContentAsMarkdown', + serverKey: 'renderUserContentAsMarkdown', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'autoMicOnEmpty', + serverKey: 'autoMicOnEmpty', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'pyInterpreterEnabled', + serverKey: 'pyInterpreterEnabled', + type: SyncableParameterType.BOOLEAN, + canSync: true + }, + { + key: 'enableContinueGeneration', + serverKey: 'enableContinueGeneration', + type: 
SyncableParameterType.BOOLEAN, + canSync: true + } +]; + +export class ParameterSyncService { + /** + * + * + * Extraction + * + * + */ + + /** + * Round floating-point numbers to avoid JavaScript precision issues. + * E.g., 0.1 + 0.2 = 0.30000000000000004 → 0.3 + * + * @param value - Parameter value to normalize + * @returns Precision-normalized value + */ + private static roundFloatingPoint(value: ParameterValue): ParameterValue { + return normalizeFloatingPoint(value) as ParameterValue; + } + + /** + * Extract server default parameters that can be synced from `/props` response. + * Handles both generation settings parameters and webui-specific settings. + * Converts samplers array to semicolon-delimited string for UI display. + * + * @param serverParams - Raw generation settings from server `/props` endpoint + * @param webuiSettings - Optional webui-specific settings from server + * @returns Record of extracted parameter key-value pairs with normalized precision + */ + static extractServerDefaults( + serverParams: ApiLlamaCppServerProps['default_generation_settings']['params'] | null, + webuiSettings?: Record + ): ParameterRecord { + const extracted: ParameterRecord = {}; + + if (serverParams) { + for (const param of SYNCABLE_PARAMETERS) { + if (param.canSync && param.serverKey in serverParams) { + const value = (serverParams as unknown as Record)[ + param.serverKey + ]; + if (value !== undefined) { + // Apply precision rounding to avoid JavaScript floating-point issues + extracted[param.key] = this.roundFloatingPoint(value); + } + } + } + + // Handle samplers array conversion to string + if (serverParams.samplers && Array.isArray(serverParams.samplers)) { + extracted.samplers = serverParams.samplers.join(';'); + } + } + + if (webuiSettings) { + for (const param of SYNCABLE_PARAMETERS) { + if (param.canSync && param.serverKey in webuiSettings) { + const value = webuiSettings[param.serverKey]; + if (value !== undefined) { + extracted[param.key] = this.roundFloatingPoint(value); + } + } + } + } + + return extracted; + } + + /** + * + * + * Merging + * + * + */ + + /** + * Merge server defaults with current user settings. + * User overrides always take priority — only parameters not in `userOverrides` + * set will be updated from server defaults. + * + * @param currentSettings - Current parameter values in the settings store + * @param serverDefaults - Default values extracted from server props + * @param userOverrides - Set of parameter keys explicitly overridden by the user + * @returns Merged parameter record with user overrides preserved + */ + static mergeWithServerDefaults( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord, + userOverrides: Set = new Set() + ): ParameterRecord { + const merged = { ...currentSettings }; + + for (const [key, serverValue] of Object.entries(serverDefaults)) { + // Only update if user hasn't explicitly overridden this parameter + if (!userOverrides.has(key)) { + merged[key] = this.roundFloatingPoint(serverValue); + } + } + + return merged; + } + + /** + * + * + * Info + * + * + */ + + /** + * Get parameter information including source and values. + * Used by ChatSettingsParameterSourceIndicator to display the correct badge + * (Custom vs Default) for each parameter in the settings UI. 
+ * + * @param key - The parameter key to get info for + * @param currentValue - The current value of the parameter + * @param propsDefaults - Server default values from `/props` + * @param userOverrides - Set of parameter keys explicitly overridden by the user + * @returns Parameter info with source, server default, and user override values + */ + static getParameterInfo( + key: string, + currentValue: ParameterValue, + propsDefaults: ParameterRecord, + userOverrides: Set + ): ParameterInfo { + const hasPropsDefault = propsDefaults[key] !== undefined; + const isUserOverride = userOverrides.has(key); + + // Simple logic: either using default (from props) or custom (user override) + const source = isUserOverride ? ParameterSource.CUSTOM : ParameterSource.DEFAULT; + + return { + value: currentValue, + source, + serverDefault: hasPropsDefault ? propsDefaults[key] : undefined, // Keep same field name for compatibility + userOverride: isUserOverride ? currentValue : undefined + }; + } + + /** + * Check if a parameter can be synced from server. + * + * @param key - The parameter key to check + * @returns True if the parameter is in the syncable parameters list + */ + static canSyncParameter(key: string): boolean { + return SYNCABLE_PARAMETERS.some((param) => param.key === key && param.canSync); + } + + /** + * Get all syncable parameter keys. + * + * @returns Array of parameter keys that can be synced from server + */ + static getSyncableParameterKeys(): string[] { + return SYNCABLE_PARAMETERS.filter((param) => param.canSync).map((param) => param.key); + } + + /** + * Validate a server parameter value against its expected type. + * + * @param key - The parameter key to validate + * @param value - The value to validate + * @returns True if value matches the expected type for this parameter + */ + static validateServerParameter(key: string, value: ParameterValue): boolean { + const param = SYNCABLE_PARAMETERS.find((p) => p.key === key); + if (!param) return false; + + switch (param.type) { + case SyncableParameterType.NUMBER: + return typeof value === 'number' && !isNaN(value); + case SyncableParameterType.STRING: + return typeof value === 'string'; + case SyncableParameterType.BOOLEAN: + return typeof value === 'boolean'; + default: + return false; + } + } + + /** + * + * + * Diff + * + * + */ + + /** + * Create a diff between current settings and server defaults. + * Shows which parameters differ from server values, useful for debugging + * and for the "Reset to defaults" functionality. 
+ * + * @param currentSettings - Current parameter values in the settings store + * @param serverDefaults - Default values extracted from server props + * @returns Record of parameter diffs with current value, server value, and whether they differ + */ + static createParameterDiff( + currentSettings: ParameterRecord, + serverDefaults: ParameterRecord + ): Record { + const diff: Record< + string, + { current: ParameterValue; server: ParameterValue; differs: boolean } + > = {}; + + for (const key of this.getSyncableParameterKeys()) { + const currentValue = currentSettings[key]; + const serverValue = serverDefaults[key]; + + if (serverValue !== undefined) { + diff[key] = { + current: currentValue, + server: serverValue, + differs: currentValue !== serverValue + }; + } + } + + return diff; + } +} diff --git a/tools/server/webui/src/lib/services/parameter-sync.ts b/tools/server/webui/src/lib/services/parameter-sync.ts index d124cf5c8d..333260701f 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.ts @@ -70,12 +70,6 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ canSync: true }, { key: 'showToolCalls', serverKey: 'showToolCalls', type: 'boolean', canSync: true }, - { - key: 'disableReasoningFormat', - serverKey: 'disableReasoningFormat', - type: 'boolean', - canSync: true - }, { key: 'keepStatsVisible', serverKey: 'keepStatsVisible', type: 'boolean', canSync: true }, { key: 'showMessageStats', serverKey: 'showMessageStats', type: 'boolean', canSync: true }, { diff --git a/tools/server/webui/src/lib/services/props.service.ts b/tools/server/webui/src/lib/services/props.service.ts new file mode 100644 index 0000000000..7373b7e016 --- /dev/null +++ b/tools/server/webui/src/lib/services/props.service.ts @@ -0,0 +1,47 @@ +import { apiFetchWithParams } from '$lib/utils/api-fetch'; + +export class PropsService { + /** + * + * + * Fetching + * + * + */ + + /** + * Fetches global server properties from the `/props` endpoint. + * In MODEL mode, returns modalities for the single loaded model. + * In ROUTER mode, returns server-wide settings without model-specific modalities. + * + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns Server properties including default generation settings and capabilities + * @throws {Error} If the request fails or returns invalid data + */ + static async fetch(autoload = false): Promise { + const params: Record = {}; + if (!autoload) { + params.autoload = 'false'; + } + + return apiFetchWithParams('./props', params, { authOnly: true }); + } + + /** + * Fetches server properties for a specific model (ROUTER mode only). + * Required in ROUTER mode because global `/props` does not include per-model modalities. 
+ * + * @param modelId - The model ID to fetch properties for + * @param autoload - If false, prevents automatic model loading (default: false) + * @returns Server properties specific to the requested model + * @throws {Error} If the request fails, model not found, or model not loaded + */ + static async fetchForModel(modelId: string, autoload = false): Promise { + const params: Record = { model: modelId }; + if (!autoload) { + params.autoload = 'false'; + } + + return apiFetchWithParams('./props', params, { authOnly: true }); + } +} diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 879b2f3245..f00f418b4c 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -15,6 +15,7 @@ import { } from '$lib/utils'; import { SvelteMap } from 'svelte/reactivity'; import { DEFAULT_CONTEXT } from '$lib/constants/default-context'; +import { SYSTEM_MESSAGE_PLACEHOLDER } from '$lib/constants/ui'; /** * chatStore - Active AI interaction and streaming state management @@ -76,6 +77,10 @@ class ChatStore { private isStreamingActive = $state(false); private isEditModeActive = $state(false); private addFilesHandler: ((files: File[]) => void) | null = $state(null); + pendingEditMessageId = $state(null); + // Draft preservation for navigation (e.g., when adding system prompt from welcome page) + private _pendingDraftMessage = $state(''); + private _pendingDraftFiles = $state([]); // ───────────────────────────────────────────────────────────────────────────── // Loading State @@ -113,6 +118,16 @@ class ChatStore { this.isLoading = this.isChatLoading(convId); const streamingState = this.getChatStreaming(convId); this.currentResponse = streamingState?.response || ''; + this.isStreamingActive = streamingState !== undefined; + this.setActiveProcessingConversation(convId); + + // Sync streaming content to activeMessages so UI displays current content + if (streamingState?.response && streamingState?.messageId) { + const idx = conversationsStore.findMessageIndex(streamingState.messageId); + if (idx !== -1) { + conversationsStore.updateMessageAtIndex(idx, { content: streamingState.response }); + } + } } /** @@ -455,6 +470,166 @@ class ChatStore { } } + /** + * Adds a system message at the top of a conversation and triggers edit mode. + * The system message is inserted between root and the first message of the active branch. + * Creates a new conversation if one doesn't exist. 
+ */ + async addSystemPrompt(): Promise { + let activeConv = conversationsStore.activeConversation; + + // Create conversation if needed + if (!activeConv) { + await conversationsStore.createConversation(); + activeConv = conversationsStore.activeConversation; + } + if (!activeConv) return; + + try { + // Get all messages to find the root + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + let rootId: string; + + // Create root message if it doesn't exist + if (!rootMessage) { + rootId = await DatabaseService.createRootMessage(activeConv.id); + } else { + rootId = rootMessage.id; + } + + // Check if there's already a system message as root's child + const existingSystemMessage = allMessages.find( + (m) => m.role === 'system' && m.parent === rootId + ); + + if (existingSystemMessage) { + // If system message exists, just trigger edit mode on it + this.pendingEditMessageId = existingSystemMessage.id; + + // Make sure it's in active messages at the beginning + if (!conversationsStore.activeMessages.some((m) => m.id === existingSystemMessage.id)) { + conversationsStore.activeMessages.unshift(existingSystemMessage); + } + return; + } + + // Find the first message of the active branch (child of root that's in activeMessages) + const activeMessages = conversationsStore.activeMessages; + const firstActiveMessage = activeMessages.find((m) => m.parent === rootId); + + // Create new system message with placeholder content (will be edited by user) + const systemMessage = await DatabaseService.createSystemMessage( + activeConv.id, + SYSTEM_MESSAGE_PLACEHOLDER, + rootId + ); + + // If there's a first message in the active branch, re-parent it to the system message + if (firstActiveMessage) { + // Update the first message's parent to be the system message + await DatabaseService.updateMessage(firstActiveMessage.id, { + parent: systemMessage.id + }); + + // Update the system message's children to include the first message + await DatabaseService.updateMessage(systemMessage.id, { + children: [firstActiveMessage.id] + }); + + // Remove first message from root's children + const updatedRootChildren = rootMessage + ? rootMessage.children.filter((id: string) => id !== firstActiveMessage.id) + : []; + // Note: system message was already added to root's children by createSystemMessage + await DatabaseService.updateMessage(rootId, { + children: [ + ...updatedRootChildren.filter((id: string) => id !== systemMessage.id), + systemMessage.id + ] + }); + + // Update local state + const firstMsgIndex = conversationsStore.findMessageIndex(firstActiveMessage.id); + if (firstMsgIndex !== -1) { + conversationsStore.updateMessageAtIndex(firstMsgIndex, { parent: systemMessage.id }); + } + } + + // Add system message to active messages at the beginning + conversationsStore.activeMessages.unshift(systemMessage); + + // Set pending edit message ID to trigger edit mode + this.pendingEditMessageId = systemMessage.id; + + conversationsStore.updateConversationTimestamp(); + } catch (error) { + console.error('Failed to add system prompt:', error); + } + } + + /** + * Removes a system message placeholder without deleting its children. + * Re-parents children back to the root message. + * If this is a new empty conversation (only root + system placeholder), deletes the entire conversation. 
+ * @returns true if the entire conversation was deleted, false otherwise + */ + async removeSystemPromptPlaceholder(messageId: string): Promise { + const activeConv = conversationsStore.activeConversation; + if (!activeConv) return false; + + try { + const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const systemMessage = allMessages.find((m) => m.id === messageId); + if (!systemMessage || systemMessage.role !== 'system') return false; + + const rootMessage = allMessages.find((m) => m.type === 'root' && m.parent === null); + if (!rootMessage) return false; + + // Check if this is a new empty conversation (only root + system placeholder) + const isEmptyConversation = allMessages.length === 2 && systemMessage.children.length === 0; + + if (isEmptyConversation) { + // Delete the entire conversation + await conversationsStore.deleteConversation(activeConv.id); + return true; + } + + // Re-parent system message's children to root + for (const childId of systemMessage.children) { + await DatabaseService.updateMessage(childId, { parent: rootMessage.id }); + + // Update local state + const childIndex = conversationsStore.findMessageIndex(childId); + if (childIndex !== -1) { + conversationsStore.updateMessageAtIndex(childIndex, { parent: rootMessage.id }); + } + } + + // Update root's children: remove system message, add system's children + const newRootChildren = [ + ...rootMessage.children.filter((id: string) => id !== messageId), + ...systemMessage.children + ]; + await DatabaseService.updateMessage(rootMessage.id, { children: newRootChildren }); + + // Delete the system message (without cascade) + await DatabaseService.deleteMessage(messageId); + + // Remove from active messages + const systemIndex = conversationsStore.findMessageIndex(messageId); + if (systemIndex !== -1) { + conversationsStore.activeMessages.splice(systemIndex, 1); + } + + conversationsStore.updateConversationTimestamp(); + return false; + } catch (error) { + console.error('Failed to remove system prompt placeholder:', error); + return false; + } + } + private async createAssistantMessage(parentId?: string): Promise { const activeConv = conversationsStore.activeConversation; if (!activeConv) return null; @@ -916,6 +1091,28 @@ class ChatStore { if (!activeConv) return { totalCount: 0, userMessages: 0, assistantMessages: 0, messageTypes: [] }; const allMessages = await conversationsStore.getConversationMessages(activeConv.id); + const messageToDelete = allMessages.find((m) => m.id === messageId); + + // For system messages, don't count descendants as they will be preserved (reparented to root) + if (messageToDelete?.role === 'system') { + const messagesToDelete = allMessages.filter((m) => m.id === messageId); + let userMessages = 0, + assistantMessages = 0; + const messageTypes: string[] = []; + + for (const msg of messagesToDelete) { + if (msg.role === 'user') { + userMessages++; + if (!messageTypes.includes('user message')) messageTypes.push('user message'); + } else if (msg.role === 'assistant') { + assistantMessages++; + if (!messageTypes.includes('assistant response')) messageTypes.push('assistant response'); + } + } + + return { totalCount: 1, userMessages, assistantMessages, messageTypes }; + } + const descendants = findDescendantMessages(allMessages, messageId); const allToDelete = [messageId, ...descendants]; const messagesToDelete = allMessages.filter((m) => allToDelete.includes(m.id)); @@ -1381,6 +1578,31 @@ class ChatStore { return this.addFilesHandler; } + 
savePendingDraft(message: string, files: ChatUploadedFile[]): void { + this._pendingDraftMessage = message; + this._pendingDraftFiles = [...files]; + } + + consumePendingDraft(): { message: string; files: ChatUploadedFile[] } | null { + if (!this._pendingDraftMessage && this._pendingDraftFiles.length === 0) { + return null; + } + + const draft = { + message: this._pendingDraftMessage, + files: [...this._pendingDraftFiles] + }; + + this._pendingDraftMessage = ''; + this._pendingDraftFiles = []; + + return draft; + } + + hasPendingDraft(): boolean { + return Boolean(this._pendingDraftMessage) || this._pendingDraftFiles.length > 0; + } + public getAllLoadingChats(): string[] { return Array.from(this.chatLoadingStates.keys()); } @@ -1427,7 +1649,7 @@ class ChatStore { // Config options needed by ChatService if (currentConfig.systemMessage) apiOptions.systemMessage = currentConfig.systemMessage; - if (currentConfig.disableReasoningFormat) apiOptions.disableReasoningFormat = true; + if (currentConfig.disableReasoningParsing) apiOptions.disableReasoningParsing = true; if (hasValue(currentConfig.temperature)) apiOptions.temperature = Number(currentConfig.temperature); @@ -1485,3 +1707,7 @@ export const isEditing = () => chatStore.isEditing(); export const isLoading = () => chatStore.isLoading; export const setEditModeActive = (handler: (files: File[]) => void) => chatStore.setEditModeActive(handler); +export const pendingEditMessageId = () => chatStore.pendingEditMessageId; +export const clearPendingEditMessageId = () => (chatStore.pendingEditMessageId = null); +export const removeSystemPromptPlaceholder = (messageId: string) => + chatStore.removeSystemPromptPlaceholder(messageId); diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 714509f024..307e3b71d9 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -1,8 +1,19 @@ -import type { ServerModelStatus, ServerRole } from '$lib/enums'; -import type { ChatMessagePromptProgress } from './chat'; +import type { ContentPartType, ServerModelStatus, ServerRole } from '$lib/enums'; +import type { ChatMessagePromptProgress, ChatRole } from './chat'; + +export interface ApiChatCompletionToolFunction { + name: string; + description?: string; + parameters: Record; +} + +export interface ApiChatCompletionTool { + type: 'function'; + function: ApiChatCompletionToolFunction; +} export interface ApiChatMessageContentPart { - type: 'text' | 'image_url' | 'input_audio'; + type: ContentPartType; text?: string; image_url?: { url: string; @@ -34,6 +45,8 @@ export interface ApiErrorResponse { export interface ApiChatMessageData { role: ChatRole; content: string | ApiChatMessageContentPart[]; + tool_calls?: ApiChatCompletionToolCall[]; + tool_call_id?: string; timestamp?: number; } @@ -188,6 +201,7 @@ export interface ApiChatCompletionRequest { stream?: boolean; model?: string; return_progress?: boolean; + tools?: ApiChatCompletionTool[]; // Reasoning parameters reasoning_format?: string; // Generation parameters @@ -247,6 +261,7 @@ export interface ApiChatCompletionStreamChunk { model?: string; tool_calls?: ApiChatCompletionToolCallDelta[]; }; + finish_reason?: string | null; }>; timings?: { prompt_n?: number; @@ -267,8 +282,9 @@ export interface ApiChatCompletionResponse { content: string; reasoning_content?: string; model?: string; - tool_calls?: ApiChatCompletionToolCallDelta[]; + tool_calls?: ApiChatCompletionToolCall[]; }; + finish_reason?: string | null; }>; 
} @@ -335,7 +351,7 @@ export interface ApiProcessingState { tokensDecoded: number; tokensRemaining: number; contextUsed: number; - contextTotal: number; + contextTotal: number | null; outputTokensUsed: number; // Total output tokens (thinking + regular content) outputTokensMax: number; // Max output tokens allowed temperature: number; diff --git a/tools/server/webui/src/lib/types/models.d.ts b/tools/server/webui/src/lib/types/models.d.ts index ef44a2cb6d..505867a1f0 100644 --- a/tools/server/webui/src/lib/types/models.d.ts +++ b/tools/server/webui/src/lib/types/models.d.ts @@ -1,8 +1,5 @@ import type { ApiModelDataEntry, ApiModelDetails } from '$lib/types/api'; -/** - * Model modalities - vision and audio capabilities - */ export interface ModelModalities { vision: boolean; audio: boolean; @@ -14,8 +11,15 @@ export interface ModelOption { model: string; description?: string; capabilities: string[]; - /** Model modalities from /props endpoint */ modalities?: ModelModalities; details?: ApiModelDetails['details']; meta?: ApiModelDataEntry['meta']; } + +/** + * Modality capabilities for file validation + */ +export interface ModalityCapabilities { + hasVision: boolean; + hasAudio: boolean; +} diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index 38b3047dd0..d894245ec3 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -18,8 +18,8 @@ export interface SettingsChatServiceOptions { model?: string; // System message to inject systemMessage?: string; - // Disable reasoning format (use 'none' instead of 'auto') - disableReasoningFormat?: boolean; + // Disable reasoning parsing (use 'none' instead of 'auto') + disableReasoningParsing?: boolean; // Generation parameters temperature?: number; max_tokens?: number; diff --git a/tools/server/webui/src/lib/utils/abort.ts b/tools/server/webui/src/lib/utils/abort.ts new file mode 100644 index 0000000000..fc4f31ec69 --- /dev/null +++ b/tools/server/webui/src/lib/utils/abort.ts @@ -0,0 +1,151 @@ +/** + * Abort Signal Utilities + * + * Provides utilities for consistent AbortSignal propagation across the application. + * These utilities help ensure that async operations can be properly cancelled + * when needed (e.g., user stops generation, navigates away, etc.). + */ + +/** + * Throws an AbortError if the signal is aborted. + * Use this at the start of async operations to fail fast. + * + * @param signal - Optional AbortSignal to check + * @throws DOMException with name 'AbortError' if signal is aborted + * + * @example + * ```ts + * async function fetchData(signal?: AbortSignal) { + * throwIfAborted(signal); + * // ... proceed with operation + * } + * ``` + */ +export function throwIfAborted(signal?: AbortSignal): void { + if (signal?.aborted) { + throw new DOMException('Operation was aborted', 'AbortError'); + } +} + +/** + * Checks if an error is an AbortError. + * Use this to distinguish between user-initiated cancellation and actual errors. 
+ * + * @param error - Error to check + * @returns true if the error is an AbortError + * + * @example + * ```ts + * try { + * await fetchData(signal); + * } catch (error) { + * if (isAbortError(error)) { + * // User cancelled - no error dialog needed + * return; + * } + * // Handle actual error + * } + * ``` + */ +export function isAbortError(error: unknown): boolean { + if (error instanceof DOMException && error.name === 'AbortError') { + return true; + } + if (error instanceof Error && error.name === 'AbortError') { + return true; + } + return false; +} + +/** + * Creates a new AbortController that is linked to one or more parent signals. + * When any parent signal aborts, the returned controller also aborts. + * + * Useful for creating child operations that should be cancelled when + * either the parent operation or their own timeout/condition triggers. + * + * @param signals - Parent signals to link to (undefined signals are ignored) + * @returns A new AbortController linked to all provided signals + * + * @example + * ```ts + * // Link to user's abort signal and add a timeout + * const linked = createLinkedController(userSignal, timeoutSignal); + * await fetch(url, { signal: linked.signal }); + * ``` + */ +export function createLinkedController(...signals: (AbortSignal | undefined)[]): AbortController { + const controller = new AbortController(); + + for (const signal of signals) { + if (!signal) continue; + + // If already aborted, abort immediately + if (signal.aborted) { + controller.abort(signal.reason); + return controller; + } + + // Link to parent signal + signal.addEventListener('abort', () => controller.abort(signal.reason), { once: true }); + } + + return controller; +} + +/** + * Creates an AbortSignal that times out after the specified duration. + * + * @param ms - Timeout duration in milliseconds + * @returns AbortSignal that will abort after the timeout + * + * @example + * ```ts + * const signal = createTimeoutSignal(5000); // 5 second timeout + * await fetch(url, { signal }); + * ``` + */ +export function createTimeoutSignal(ms: number): AbortSignal { + return AbortSignal.timeout(ms); +} + +/** + * Wraps a promise to reject if the signal is aborted. + * Useful for making non-abortable promises respect an AbortSignal. 
+ * + * @param promise - Promise to wrap + * @param signal - AbortSignal to respect + * @returns Promise that rejects with AbortError if signal aborts + * + * @example + * ```ts + * // Make a non-abortable operation respect abort signal + * const result = await withAbortSignal( + * someNonAbortableOperation(), + * signal + * ); + * ``` + */ +export async function withAbortSignal(promise: Promise, signal?: AbortSignal): Promise { + if (!signal) return promise; + + throwIfAborted(signal); + + return new Promise((resolve, reject) => { + const abortHandler = () => { + reject(new DOMException('Operation was aborted', 'AbortError')); + }; + + signal.addEventListener('abort', abortHandler, { once: true }); + + promise + .then((value) => { + signal.removeEventListener('abort', abortHandler); + resolve(value); + }) + .catch((error) => { + signal.removeEventListener('abort', abortHandler); + reject(error); + }); + }); +} diff --git a/tools/server/webui/src/lib/utils/api-fetch.ts b/tools/server/webui/src/lib/utils/api-fetch.ts new file mode 100644 index 0000000000..28757a966f --- /dev/null +++ b/tools/server/webui/src/lib/utils/api-fetch.ts @@ -0,0 +1,154 @@ +import { base } from '$app/paths'; +import { getJsonHeaders, getAuthHeaders } from './api-headers'; + +/** + * API Fetch Utilities + * + * Provides common fetch patterns used across services: + * - Automatic JSON headers + * - Error handling with proper error messages + * - Base path resolution + */ + +export interface ApiFetchOptions extends Omit { + /** + * Use auth-only headers (no Content-Type). + * Default: false (uses JSON headers with Content-Type: application/json) + */ + authOnly?: boolean; + /** + * Additional headers to merge with default headers. + */ + headers?: Record; +} + +/** + * Fetch JSON data from an API endpoint with standard headers and error handling. + * + * @param path - API path (will be prefixed with base path) + * @param options - Fetch options with additional authOnly flag + * @returns Parsed JSON response + * @throws Error with formatted message on failure + * + * @example + * ```typescript + * // GET request + * const models = await apiFetch('/v1/models'); + * + * // POST request + * const result = await apiFetch('/models/load', { + * method: 'POST', + * body: JSON.stringify({ model: 'gpt-4' }) + * }); + * ``` + */ +export async function apiFetch(path: string, options: ApiFetchOptions = {}): Promise { + const { authOnly = false, headers: customHeaders, ...fetchOptions } = options; + + const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders(); + const headers = { ...baseHeaders, ...customHeaders }; + + const url = path.startsWith('http://') || path.startsWith('https://') ? path : `${base}${path}`; + + const response = await fetch(url, { + ...fetchOptions, + headers + }); + + if (!response.ok) { + const errorMessage = await parseErrorMessage(response); + throw new Error(errorMessage); + } + + return response.json() as Promise; +} + +/** + * Fetch with URL constructed from base URL and query parameters. 
+ * + * @param basePath - Base API path + * @param params - Query parameters to append + * @param options - Fetch options + * @returns Parsed JSON response + * + * @example + * ```typescript + * const props = await apiFetchWithParams('./props', { + * model: 'gpt-4', + * autoload: 'false' + * }); + * ``` + */ +export async function apiFetchWithParams( + basePath: string, + params: Record, + options: ApiFetchOptions = {} +): Promise { + const url = new URL(basePath, window.location.href); + + for (const [key, value] of Object.entries(params)) { + if (value !== undefined && value !== null) { + url.searchParams.set(key, value); + } + } + + const { authOnly = false, headers: customHeaders, ...fetchOptions } = options; + + const baseHeaders = authOnly ? getAuthHeaders() : getJsonHeaders(); + const headers = { ...baseHeaders, ...customHeaders }; + + const response = await fetch(url.toString(), { + ...fetchOptions, + headers + }); + + if (!response.ok) { + const errorMessage = await parseErrorMessage(response); + throw new Error(errorMessage); + } + + return response.json() as Promise; +} + +/** + * POST JSON data to an API endpoint. + * + * @param path - API path + * @param body - Request body (will be JSON stringified) + * @param options - Additional fetch options + * @returns Parsed JSON response + */ +export async function apiPost( + path: string, + body: B, + options: ApiFetchOptions = {} +): Promise { + return apiFetch(path, { + method: 'POST', + body: JSON.stringify(body), + ...options + }); +} + +/** + * Parse error message from a failed response. + * Tries to extract error message from JSON body, falls back to status text. + */ +async function parseErrorMessage(response: Response): Promise { + try { + const errorData = await response.json(); + if (errorData?.error?.message) { + return errorData.error.message; + } + if (errorData?.error && typeof errorData.error === 'string') { + return errorData.error; + } + if (errorData?.message) { + return errorData.message; + } + } catch { + // JSON parsing failed, use status text + } + + return `Request failed: ${response.status} ${response.statusText}`; +} diff --git a/tools/server/webui/src/lib/utils/branching.ts b/tools/server/webui/src/lib/utils/branching.ts index 3be56047a5..ee3a505eed 100644 --- a/tools/server/webui/src/lib/utils/branching.ts +++ b/tools/server/webui/src/lib/utils/branching.ts @@ -15,6 +15,8 @@ * └── message 5 (assistant) */ +import { MessageRole } from '$lib/enums/chat'; + /** * Filters messages to get the conversation path from root to a specific leaf node. * If the leafNodeId doesn't exist, returns the path with the latest timestamp. 
@@ -65,8 +67,13 @@ export function filterByLeafNodeId( currentNode = nodeMap.get(currentNode.parent); } - // Sort by timestamp to get chronological order (root to leaf) - result.sort((a, b) => a.timestamp - b.timestamp); + // Sort: system messages first, then by timestamp + result.sort((a, b) => { + if (a.role === MessageRole.SYSTEM && b.role !== MessageRole.SYSTEM) return -1; + if (a.role !== MessageRole.SYSTEM && b.role === MessageRole.SYSTEM) return 1; + + return a.timestamp - b.timestamp; + }); return result; } diff --git a/tools/server/webui/src/lib/utils/browser-only.ts b/tools/server/webui/src/lib/utils/browser-only.ts index 0af800638b..27d2be4aaa 100644 --- a/tools/server/webui/src/lib/utils/browser-only.ts +++ b/tools/server/webui/src/lib/utils/browser-only.ts @@ -23,7 +23,7 @@ export { } from './pdf-processing'; // File conversion utilities (depends on pdf-processing) -export { parseFilesToMessageExtras, type FileProcessingResult } from './convert-files-to-extra'; +export { parseFilesToMessageExtras } from './convert-files-to-extra'; // File upload processing utilities (depends on pdf-processing, svg-to-png, webp-to-png) export { processFilesToChatUploaded } from './process-uploaded-files'; diff --git a/tools/server/webui/src/lib/utils/cache-ttl.ts b/tools/server/webui/src/lib/utils/cache-ttl.ts new file mode 100644 index 0000000000..9d1f005822 --- /dev/null +++ b/tools/server/webui/src/lib/utils/cache-ttl.ts @@ -0,0 +1,293 @@ +const DEFAULT_CACHE_TTL_MS = 5 * 60 * 1000; +const DEFAULT_CACHE_MAX_ENTRIES = 100; + +/** + * TTL Cache - Time-To-Live cache implementation for memory optimization + * + * Provides automatic expiration of cached entries to prevent memory bloat + * in long-running sessions. + * + * @example + * ```ts + * const cache = new TTLCache({ ttlMs: 5 * 60 * 1000 }); // 5 minutes + * cache.set('key', data); + * const value = cache.get('key'); // null if expired + * ``` + */ + +export interface TTLCacheOptions { + /** Time-to-live in milliseconds. Default: 5 minutes */ + ttlMs?: number; + /** Maximum number of entries. Oldest entries are evicted when exceeded. Default: 100 */ + maxEntries?: number; + /** Callback when an entry expires or is evicted */ + onEvict?: (key: string, value: unknown) => void; +} + +interface CacheEntry { + value: T; + expiresAt: number; + lastAccessed: number; +} + +export class TTLCache { + private cache = new Map>(); + private readonly ttlMs: number; + private readonly maxEntries: number; + private readonly onEvict?: (key: string, value: unknown) => void; + + constructor(options: TTLCacheOptions = {}) { + this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS; + this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES; + this.onEvict = options.onEvict; + } + + /** + * Get a value from cache. Returns null if expired or not found. + */ + get(key: K): V | null { + const entry = this.cache.get(key); + if (!entry) return null; + + if (Date.now() > entry.expiresAt) { + this.delete(key); + return null; + } + + // Update last accessed time for LRU-like behavior + entry.lastAccessed = Date.now(); + return entry.value; + } + + /** + * Set a value in cache with TTL. + */ + set(key: K, value: V, customTtlMs?: number): void { + // Evict oldest entries if at capacity + if (this.cache.size >= this.maxEntries && !this.cache.has(key)) { + this.evictOldest(); + } + + const ttl = customTtlMs ?? 
this.ttlMs; + const now = Date.now(); + + this.cache.set(key, { + value, + expiresAt: now + ttl, + lastAccessed: now + }); + } + + /** + * Check if key exists and is not expired. + */ + has(key: K): boolean { + const entry = this.cache.get(key); + if (!entry) return false; + + if (Date.now() > entry.expiresAt) { + this.delete(key); + return false; + } + + return true; + } + + /** + * Delete a specific key from cache. + */ + delete(key: K): boolean { + const entry = this.cache.get(key); + if (entry && this.onEvict) { + this.onEvict(key, entry.value); + } + return this.cache.delete(key); + } + + /** + * Clear all entries from cache. + */ + clear(): void { + if (this.onEvict) { + for (const [key, entry] of this.cache) { + this.onEvict(key, entry.value); + } + } + this.cache.clear(); + } + + /** + * Get the number of entries (including potentially expired ones). + */ + get size(): number { + return this.cache.size; + } + + /** + * Remove all expired entries from cache. + * Call periodically for proactive cleanup. + */ + prune(): number { + const now = Date.now(); + let pruned = 0; + + for (const [key, entry] of this.cache) { + if (now > entry.expiresAt) { + this.delete(key); + pruned++; + } + } + + return pruned; + } + + /** + * Get all valid (non-expired) keys. + */ + keys(): K[] { + const now = Date.now(); + const validKeys: K[] = []; + + for (const [key, entry] of this.cache) { + if (now <= entry.expiresAt) { + validKeys.push(key); + } + } + + return validKeys; + } + + /** + * Evict the oldest (least recently accessed) entry. + */ + private evictOldest(): void { + let oldestKey: K | null = null; + let oldestTime = Infinity; + + for (const [key, entry] of this.cache) { + if (entry.lastAccessed < oldestTime) { + oldestTime = entry.lastAccessed; + oldestKey = key; + } + } + + if (oldestKey !== null) { + this.delete(oldestKey); + } + } + + /** + * Refresh TTL for an existing entry without changing the value. + */ + touch(key: K): boolean { + const entry = this.cache.get(key); + if (!entry) return false; + + const now = Date.now(); + if (now > entry.expiresAt) { + this.delete(key); + return false; + } + + entry.expiresAt = now + this.ttlMs; + entry.lastAccessed = now; + return true; + } +} + +/** + * Reactive TTL Map for Svelte stores + * Wraps SvelteMap with TTL functionality + */ +export class ReactiveTTLMap { + private entries = $state>>(new Map()); + private readonly ttlMs: number; + private readonly maxEntries: number; + + constructor(options: TTLCacheOptions = {}) { + this.ttlMs = options.ttlMs ?? DEFAULT_CACHE_TTL_MS; + this.maxEntries = options.maxEntries ?? DEFAULT_CACHE_MAX_ENTRIES; + } + + get(key: K): V | null { + const entry = this.entries.get(key); + if (!entry) return null; + + if (Date.now() > entry.expiresAt) { + this.entries.delete(key); + return null; + } + + entry.lastAccessed = Date.now(); + return entry.value; + } + + set(key: K, value: V, customTtlMs?: number): void { + if (this.entries.size >= this.maxEntries && !this.entries.has(key)) { + this.evictOldest(); + } + + const ttl = customTtlMs ?? 
this.ttlMs; + const now = Date.now(); + + this.entries.set(key, { + value, + expiresAt: now + ttl, + lastAccessed: now + }); + } + + has(key: K): boolean { + const entry = this.entries.get(key); + if (!entry) return false; + + if (Date.now() > entry.expiresAt) { + this.entries.delete(key); + return false; + } + + return true; + } + + delete(key: K): boolean { + return this.entries.delete(key); + } + + clear(): void { + this.entries.clear(); + } + + get size(): number { + return this.entries.size; + } + + prune(): number { + const now = Date.now(); + let pruned = 0; + + for (const [key, entry] of this.entries) { + if (now > entry.expiresAt) { + this.entries.delete(key); + pruned++; + } + } + + return pruned; + } + + private evictOldest(): void { + let oldestKey: K | null = null; + let oldestTime = Infinity; + + for (const [key, entry] of this.entries) { + if (entry.lastAccessed < oldestTime) { + oldestTime = entry.lastAccessed; + oldestKey = key; + } + } + + if (oldestKey !== null) { + this.entries.delete(oldestKey); + } + } +} diff --git a/tools/server/webui/src/lib/utils/code.ts b/tools/server/webui/src/lib/utils/code.ts new file mode 100644 index 0000000000..67efc6b27e --- /dev/null +++ b/tools/server/webui/src/lib/utils/code.ts @@ -0,0 +1,85 @@ +import hljs from 'highlight.js'; +import { + NEWLINE, + DEFAULT_LANGUAGE, + LANG_PATTERN, + AMPERSAND_REGEX, + LT_REGEX, + GT_REGEX, + FENCE_PATTERN +} from '$lib/constants/code'; + +export interface IncompleteCodeBlock { + language: string; + code: string; + openingIndex: number; +} + +/** + * Highlights code using highlight.js + * @param code - The code to highlight + * @param language - The programming language + * @returns HTML string with syntax highlighting + */ +export function highlightCode(code: string, language: string): string { + if (!code) return ''; + + try { + const lang = language.toLowerCase(); + const isSupported = hljs.getLanguage(lang); + + if (isSupported) { + return hljs.highlight(code, { language: lang }).value; + } else { + return hljs.highlightAuto(code).value; + } + } catch { + // Fallback to escaped plain text + return code + .replace(AMPERSAND_REGEX, '&') + .replace(LT_REGEX, '<') + .replace(GT_REGEX, '>'); + } +} + +/** + * Detects if markdown ends with an incomplete code block (opened but not closed). + * Returns the code block info if found, null otherwise. + * @param markdown - The raw markdown string to check + * @returns IncompleteCodeBlock info or null + */ +export function detectIncompleteCodeBlock(markdown: string): IncompleteCodeBlock | null { + // Count all code fences in the markdown + // A code block is incomplete if there's an odd number of ``` fences + const fencePattern = new RegExp(FENCE_PATTERN.source, FENCE_PATTERN.flags); + const fences: number[] = []; + let fenceMatch; + + while ((fenceMatch = fencePattern.exec(markdown)) !== null) { + // Store the position after the ``` + const pos = fenceMatch[0].startsWith(NEWLINE) ? 
fenceMatch.index + 1 : fenceMatch.index; + fences.push(pos); + } + + // If even number of fences (including 0), all code blocks are closed + if (fences.length % 2 === 0) { + return null; + } + + // Odd number means last code block is incomplete + // The last fence is the opening of the incomplete block + const openingIndex = fences[fences.length - 1]; + const afterOpening = markdown.slice(openingIndex + 3); + + // Extract language and code content + const langMatch = afterOpening.match(LANG_PATTERN); + const language = langMatch?.[1] || DEFAULT_LANGUAGE; + const codeStartIndex = openingIndex + 3 + (langMatch?.[0]?.length ?? 0); + const code = markdown.slice(codeStartIndex); + + return { + language, + code, + openingIndex + }; +} diff --git a/tools/server/webui/src/lib/utils/data-url.ts b/tools/server/webui/src/lib/utils/data-url.ts new file mode 100644 index 0000000000..6f55be793d --- /dev/null +++ b/tools/server/webui/src/lib/utils/data-url.ts @@ -0,0 +1,10 @@ +/** + * Creates a base64 data URL from MIME type and base64-encoded data. + * + * @param mimeType - The MIME type (e.g., 'image/png', 'audio/mp3') + * @param base64Data - The base64-encoded data + * @returns A data URL string in format 'data:{mimeType};base64,{data}' + */ +export function createBase64DataUrl(mimeType: string, base64Data: string): string { + return `data:${mimeType};base64,${base64Data}`; +} diff --git a/tools/server/webui/src/lib/utils/debounce.ts b/tools/server/webui/src/lib/utils/debounce.ts new file mode 100644 index 0000000000..90a5a01783 --- /dev/null +++ b/tools/server/webui/src/lib/utils/debounce.ts @@ -0,0 +1,22 @@ +/** + * @param fn - The function to debounce + * @param delay - The delay in milliseconds + * @returns A debounced version of the function + */ +export function debounce) => void>( + fn: T, + delay: number +): (...args: Parameters) => void { + let timeoutId: ReturnType | null = null; + + return (...args: Parameters) => { + if (timeoutId) { + clearTimeout(timeoutId); + } + + timeoutId = setTimeout(() => { + fn(...args); + timeoutId = null; + }, delay); + }; +} diff --git a/tools/server/webui/src/lib/utils/formatters.ts b/tools/server/webui/src/lib/utils/formatters.ts index ae9f59a39c..bdf2ca26fd 100644 --- a/tools/server/webui/src/lib/utils/formatters.ts +++ b/tools/server/webui/src/lib/utils/formatters.ts @@ -51,3 +51,75 @@ export function formatNumber(num: number | unknown): string { return num.toLocaleString(); } + +/** + * Format JSON string with pretty printing (2-space indentation) + * Returns original string if parsing fails + * + * @param jsonString - JSON string to format + * @returns Pretty-printed JSON string or original if invalid + */ +export function formatJsonPretty(jsonString: string): string { + try { + const parsed = JSON.parse(jsonString); + return JSON.stringify(parsed, null, 2); + } catch { + return jsonString; + } +} + +/** + * Format time as HH:MM:SS in 24-hour format + * + * @param date - Date object to format + * @returns Formatted time string (HH:MM:SS) + */ +export function formatTime(date: Date): string { + return date.toLocaleTimeString('en-US', { + hour12: false, + hour: '2-digit', + minute: '2-digit', + second: '2-digit' + }); +} + +/** + * Formats milliseconds to a human-readable time string for performance metrics. 
+ * Examples: "4h 12min 54s", "12min 34s", "45s", "0.5s" + * + * @param ms - Time in milliseconds + * @returns Formatted time string + */ +export function formatPerformanceTime(ms: number): string { + if (ms < 0) return '0s'; + + const totalSeconds = ms / 1000; + + if (totalSeconds < 1) { + return `${totalSeconds.toFixed(1)}s`; + } + + if (totalSeconds < 10) { + return `${totalSeconds.toFixed(1)}s`; + } + + const hours = Math.floor(totalSeconds / 3600); + const minutes = Math.floor((totalSeconds % 3600) / 60); + const seconds = Math.floor(totalSeconds % 60); + + const parts: string[] = []; + + if (hours > 0) { + parts.push(`${hours}h`); + } + + if (minutes > 0) { + parts.push(`${minutes}min`); + } + + if (seconds > 0 || parts.length === 0) { + parts.push(`${seconds}s`); + } + + return parts.join(' '); +} diff --git a/tools/server/webui/src/lib/utils/image-error-fallback.ts b/tools/server/webui/src/lib/utils/image-error-fallback.ts new file mode 100644 index 0000000000..6e3260f4ae --- /dev/null +++ b/tools/server/webui/src/lib/utils/image-error-fallback.ts @@ -0,0 +1,10 @@ +/** + * Simplified HTML fallback for external images that fail to load. + * Displays a centered message with a link to open the image in a new tab. + */ +export function getImageErrorFallbackHtml(src: string): string { + return `
+		<div style="text-align: center;">
+			Image cannot be displayed
+			<a href="${src}" target="_blank">(open link)</a>
+		</div>
`; +} diff --git a/tools/server/webui/src/lib/utils/index.ts b/tools/server/webui/src/lib/utils/index.ts index 588167b8ca..5eb2bbaea1 100644 --- a/tools/server/webui/src/lib/utils/index.ts +++ b/tools/server/webui/src/lib/utils/index.ts @@ -9,6 +9,7 @@ // API utilities export { getAuthHeaders, getJsonHeaders } from './api-headers'; +export { apiFetch, apiFetchWithParams, apiPost, type ApiFetchOptions } from './api-fetch'; export { validateApiKey } from './api-key-validation'; // Attachment utilities @@ -75,8 +76,7 @@ export { maskInlineLaTeX, preprocessLaTeX } from './latex-protection'; export { isFileTypeSupportedByModel, filterFilesByModalities, - generateModalityErrorMessage, - type ModalityCapabilities + generateModalityErrorMessage } from './modality-file-validation'; // Model name utilities @@ -93,3 +93,6 @@ export { getLanguageFromFilename } from './syntax-highlight-language'; // Text file utilities export { isTextFileByName, readFileAsText, isLikelyTextFile } from './text-files'; + +// Image error fallback utilities +export { getImageErrorFallbackHtml } from './image-error-fallback'; diff --git a/tools/server/webui/src/lib/utils/modality-file-validation.ts b/tools/server/webui/src/lib/utils/modality-file-validation.ts index 136c084146..02fb4e4a36 100644 --- a/tools/server/webui/src/lib/utils/modality-file-validation.ts +++ b/tools/server/webui/src/lib/utils/modality-file-validation.ts @@ -5,12 +5,7 @@ import { getFileTypeCategory } from '$lib/utils'; import { FileTypeCategory } from '$lib/enums'; - -/** Modality capabilities for file validation */ -export interface ModalityCapabilities { - hasVision: boolean; - hasAudio: boolean; -} +import type { ModalityCapabilities } from '$lib/types/models'; /** * Check if a file type is supported by the given modalities diff --git a/tools/server/webui/src/lib/utils/text-files.ts b/tools/server/webui/src/lib/utils/text-files.ts index e8006de64d..2f1a575d1d 100644 --- a/tools/server/webui/src/lib/utils/text-files.ts +++ b/tools/server/webui/src/lib/utils/text-files.ts @@ -3,10 +3,8 @@ * Handles text file detection, reading, and validation */ -import { - DEFAULT_BINARY_DETECTION_OPTIONS, - type BinaryDetectionOptions -} from '$lib/constants/binary-detection'; +import { DEFAULT_BINARY_DETECTION_OPTIONS } from '$lib/constants/binary-detection'; +import type { BinaryDetectionOptions } from '$lib/constants/binary-detection'; import { FileExtensionText } from '$lib/enums'; /** diff --git a/tools/server/webui/tests/stories/ChatForm.stories.svelte b/tools/server/webui/tests/stories/ChatForm.stories.svelte index 18319e8e61..a8a4c21b44 100644 --- a/tools/server/webui/tests/stories/ChatForm.stories.svelte +++ b/tools/server/webui/tests/stories/ChatForm.stories.svelte @@ -2,7 +2,6 @@ import { defineMeta } from '@storybook/addon-svelte-csf'; import ChatForm from '$lib/components/app/chat/ChatForm/ChatForm.svelte'; import { expect } from 'storybook/test'; - import { mockServerProps, mockConfigs } from './fixtures/storybook-mocks'; import jpgAsset from './fixtures/assets/1.jpg?url'; import svgAsset from './fixtures/assets/hf-logo.svg?url'; import pdfAsset from './fixtures/assets/example.pdf?raw'; @@ -46,8 +45,6 @@ name="Default" args={{ class: 'max-w-[56rem] w-[calc(100vw-2rem)]' }} play={async ({ canvas, userEvent }) => { - mockServerProps(mockConfigs.noModalities); - const textarea = await canvas.findByRole('textbox'); const submitButton = await canvas.findByRole('button', { name: 'Send' }); @@ -66,73 +63,11 @@ const fileInput = 
document.querySelector('input[type="file"]'); await expect(fileInput).not.toHaveAttribute('accept'); - - // Open file attachments dropdown - const fileUploadButton = canvas.getByText('Attach files'); - await userEvent.click(fileUploadButton); - - // Check dropdown menu items are disabled (no modalities) - const imagesButton = document.querySelector('.images-button'); - const audioButton = document.querySelector('.audio-button'); - - await expect(imagesButton).toHaveAttribute('data-disabled'); - await expect(audioButton).toHaveAttribute('data-disabled'); - - // Close dropdown by pressing Escape - await userEvent.keyboard('{Escape}'); }} /> - { - mockServerProps(mockConfigs.visionOnly); - - // Open file attachments dropdown and verify it works - const fileUploadButton = canvas.getByText('Attach files'); - await userEvent.click(fileUploadButton); - - // Verify dropdown menu items exist - const imagesButton = document.querySelector('.images-button'); - const audioButton = document.querySelector('.audio-button'); - - await expect(imagesButton).toBeInTheDocument(); - await expect(audioButton).toBeInTheDocument(); - - // Close dropdown by pressing Escape - await userEvent.keyboard('{Escape}'); - - console.log('✅ Vision modality: Dropdown menu verified'); - }} -/> - - { - mockServerProps(mockConfigs.audioOnly); - - // Open file attachments dropdown and verify it works - const fileUploadButton = canvas.getByText('Attach files'); - await userEvent.click(fileUploadButton); - - // Verify dropdown menu items exist - const imagesButton = document.querySelector('.images-button'); - const audioButton = document.querySelector('.audio-button'); - - await expect(imagesButton).toBeInTheDocument(); - await expect(audioButton).toBeInTheDocument(); - - // Close dropdown by pressing Escape - await userEvent.keyboard('{Escape}'); - - console.log('✅ Audio modality: Dropdown menu verified'); - }} -/> - { - mockServerProps(mockConfigs.bothModalities); - const jpgAttachment = canvas.getByAltText('1.jpg'); const svgAttachment = canvas.getByAltText('hf-logo.svg'); const pdfFileExtension = canvas.getByText('PDF'); diff --git a/tools/server/webui/tests/stories/ChatMessage.stories.svelte b/tools/server/webui/tests/stories/ChatMessage.stories.svelte index 5f4de7d476..a3579cf04e 100644 --- a/tools/server/webui/tests/stories/ChatMessage.stories.svelte +++ b/tools/server/webui/tests/stories/ChatMessage.stories.svelte @@ -93,7 +93,7 @@ }} play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - settingsStore.updateConfig('disableReasoningFormat', false); + settingsStore.updateConfig('showRawOutputSwitch', false); }} /> @@ -105,7 +105,7 @@ }} play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - settingsStore.updateConfig('disableReasoningFormat', false); + settingsStore.updateConfig('showRawOutputSwitch', false); }} /> @@ -117,7 +117,7 @@ }} play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - settingsStore.updateConfig('disableReasoningFormat', false); + settingsStore.updateConfig('showRawOutputSwitch', false); }} /> @@ -129,7 +129,7 @@ }} play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - settingsStore.updateConfig('disableReasoningFormat', true); + settingsStore.updateConfig('showRawOutputSwitch', true); }} /> @@ -141,7 +141,7 @@ asChild play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - 
settingsStore.updateConfig('disableReasoningFormat', false); + settingsStore.updateConfig('showRawOutputSwitch', false); // Phase 1: Stream reasoning content in chunks let reasoningText = 'I need to think about this carefully. Let me break down the problem:\n\n1. The user is asking for help with something complex\n2. I should provide a thorough and helpful response\n3. I need to consider multiple approaches\n4. The best solution would be to explain step by step\n\nThis approach will ensure clarity and understanding.'; @@ -193,7 +193,7 @@ }} play={async () => { const { settingsStore } = await import('$lib/stores/settings.svelte'); - settingsStore.updateConfig('disableReasoningFormat', false); + settingsStore.updateConfig('showRawOutputSwitch', false); // Import the chat store to simulate loading state const { chatStore } = await import('$lib/stores/chat.svelte'); diff --git a/tools/tts/README.md b/tools/tts/README.md index 48302c070b..4749bb9f5a 100644 --- a/tools/tts/README.md +++ b/tools/tts/README.md @@ -34,7 +34,7 @@ $ build/bin/llama-quantize models/outetts-0.2-0.5B-f16.gguf \ ``` The quantized model will be `models/outetts-0.2-0.5B-q8_0.gguf`. -Next we do something simlar for the audio decoder. First download or checkout +Next we do something similar for the audio decoder. First download or checkout the model for the voice decoder: ```console $ pushd models @@ -42,7 +42,7 @@ $ git clone --branch main --single-branch --depth 1 https://huggingface.co/novat $ cd WavTokenizer-large-speech-75token && git lfs install && git lfs pull $ popd ``` -This model file is PyTorch checkpoint (.ckpt) and we first need to convert it to +This model file is a PyTorch checkpoint (.ckpt) and we first need to convert it to huggingface format: ```console (venv) python tools/tts/convert_pt_to_hf.py \ diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp index 8c39fce8ba..ac55a8b1ca 100644 --- a/tools/tts/tts.cpp +++ b/tools/tts/tts.cpp @@ -1036,7 +1036,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 #if 1 // spectral operations - const int n_embd = llama_model_n_embd(model_cts); + const int n_embd = llama_model_n_embd_out(model_cts); const float * embd = llama_get_embeddings(ctx_cts); auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index a8a59e02f4..a5887476af 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -39,7 +39,7 @@ if (LLAMA_BUILD_BORINGSSL) set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") - set(BORINGSSL_VERSION "0.20260204.0" CACHE STRING "BoringSSL version") + set(BORINGSSL_VERSION "0.20260211.0" CACHE STRING "BoringSSL version") message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index ba5f9c8ff9..e309a7ad5d 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -6,8 +6,472 @@ namespace httplib { * Implementation that will be part of the .cc file if split into .h + .cc. 
*/ +namespace stream { + +// stream::Result implementations +Result::Result() : chunk_size_(8192) {} + +Result::Result(ClientImpl::StreamHandle &&handle, size_t chunk_size) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + +Result::Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; +} + +Result &Result::operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; +} + +bool Result::is_valid() const { return handle_.is_valid(); } +Result::operator bool() const { return is_valid(); } + +int Result::status() const { + return handle_.response ? handle_.response->status : -1; +} + +const Headers &Result::headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; +} + +std::string Result::get_header_value(const std::string &key, + const char *def) const { + return handle_.response ? handle_.response->get_header_value(key, def) : def; +} + +bool Result::has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; +} + +Error Result::error() const { return handle_.error; } +Error Result::read_error() const { return handle_.get_read_error(); } +bool Result::has_read_error() const { return handle_.has_read_error(); } + +bool Result::next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; +} + +const char *Result::data() const { return buffer_.data(); } +size_t Result::size() const { return current_size_; } + +std::string Result::read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; +} + +} // namespace stream + +namespace sse { + +// SSEMessage implementations +SSEMessage::SSEMessage() : event("message") {} + +void SSEMessage::clear() { + event = "message"; + data.clear(); + id.clear(); +} + +// SSEClient implementations +SSEClient::SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + +SSEClient::SSEClient(Client &client, const std::string &path, + const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + +SSEClient::~SSEClient() { stop(); } + +SSEClient &SSEClient::on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_event(const std::string &type, + MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; +} + +SSEClient &SSEClient::set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; +} + +bool 
SSEClient::is_connected() const { return connected_.load(); } + +const std::string &SSEClient::last_event_id() const { + return last_event_id_; +} + +void SSEClient::start() { + running_.store(true); + run_event_loop(); +} + +void SSEClient::start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); +} + +void SSEClient::stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } +} + +bool SSEClient::parse_sse_line(const std::string &line, SSEMessage &msg, + int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + { + int v = 0; + auto res = + detail::from_chars(value.data(), value.data() + value.size(), v); + if (res.ec == std::errc{}) { retry_ms = v; } + } + } + // Unknown fields are ignored per SSE spec + + return false; +} + +void SSEClient::run_event_loop() { + auto reconnect_count = 0; + + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + auto request_headers = headers_; + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } + + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != 200) { + connected_.store(false); + // For certain errors, don't reconnect + if (result.status() == 204 || // No Content - server wants us to stop + result.status() == 404 || // Not Found + result.status() == 401 || // Unauthorized + result.status() == 403) { // Forbidden + if (on_error_) { on_error_(Error::Connection); } + break; + } + + if (on_error_) { on_error_(Error::Connection); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while 
((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); +} + +void SSEClient::dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } +} + +bool SSEClient::should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; +} + +void SSEClient::wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } +} + +} // namespace sse + +#ifdef CPPHTTPLIB_SSL_ENABLED +/* + * TLS abstraction layer - internal function declarations + * These are implementation details and not part of the public API. 
+ */ +namespace tls { + +// Client context +ctx_t create_client_context(); +void free_context(ctx_t ctx); +bool set_min_version(ctx_t ctx, Version version); +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len); +bool load_ca_file(ctx_t ctx, const char *file_path); +bool load_ca_dir(ctx_t ctx, const char *dir_path); +bool load_system_certs(ctx_t ctx); +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); + +// Server context +ctx_t create_server_context(); +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); +bool set_client_ca_file(ctx_t ctx, const char *ca_file, const char *ca_dir); +void set_verify_client(ctx_t ctx, bool require); + +// Session management +session_t create_session(ctx_t ctx, socket_t sock); +void free_session(session_t session); +bool set_sni(session_t session, const char *hostname); +bool set_hostname(session_t session, const char *hostname); + +// Handshake (non-blocking capable) +TlsError connect(session_t session); +TlsError accept(session_t session); + +// Handshake with timeout (blocking until timeout) +bool connect_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); +bool accept_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); + +// I/O (non-blocking capable) +ssize_t read(session_t session, void *buf, size_t len, TlsError &err); +ssize_t write(session_t session, const void *buf, size_t len, TlsError &err); +int pending(const_session_t session); +void shutdown(session_t session, bool graceful); + +// Connection state +bool is_peer_closed(session_t session, socket_t sock); + +// Certificate verification +cert_t get_peer_cert(const_session_t session); +void free_cert(cert_t cert); +bool verify_hostname(cert_t cert, const char *hostname); +uint64_t hostname_mismatch_code(); +long get_verify_result(const_session_t session); + +// Certificate introspection +std::string get_cert_subject_cn(cert_t cert); +std::string get_cert_issuer_name(cert_t cert); +bool get_cert_sans(cert_t cert, std::vector &sans); +bool get_cert_validity(cert_t cert, time_t ¬_before, time_t ¬_after); +std::string get_cert_serial(cert_t cert); +bool get_cert_der(cert_t cert, std::vector &der); +const char *get_sni(const_session_t session); + +// CA store management +ca_store_t create_ca_store(const char *pem, size_t len); +void free_ca_store(ca_store_t store); +bool set_ca_store(ctx_t ctx, ca_store_t store); +size_t get_ca_certs(ctx_t ctx, std::vector &certs); +std::vector get_ca_names(ctx_t ctx); + +// Dynamic certificate update (for servers) +bool update_server_cert(ctx_t ctx, const char *cert_pem, const char *key_pem, + const char *password); +bool update_server_client_ca(ctx_t ctx, const char *ca_pem); + +// Certificate verification callback +bool set_verify_callback(ctx_t ctx, VerifyCallback callback); +long get_verify_error(const_session_t session); +std::string verify_error_string(long error_code); + +// TlsError information +uint64_t peek_error(); +uint64_t get_error(); +std::string error_string(uint64_t code); + +} // namespace tls +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 1: detail namespace - Non-SSL utilities + */ + namespace detail { +bool set_socket_opt_impl(socket_t 
sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN32 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN32 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + bool is_hex(char c, int &v) { if (isdigit(c)) { v = c - '0'; @@ -940,39 +1404,6 @@ private: static const size_t read_buff_size_ = 1024l * 4; }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream final : public Stream { -public: - SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, time_t max_timeout_msec = 0, - std::chrono::time_point start_time = - (std::chrono::steady_clock::time_point::min)()); - ~SSLSocketStream() override; - - bool is_readable() const override; - bool wait_readable() const override; - bool wait_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; - void get_local_ip_and_port(std::string &ip, int &port) const override; - socket_t socket() const override; - time_t duration() const override; - -private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; - time_t max_timeout_msec_; - const std::chrono::time_point start_time_; -}; -#endif - bool keep_alive(const std::atomic &svr_sock, socket_t sock, time_t keep_alive_timeout_sec) { using namespace std::chrono; @@ -2270,14 +2701,23 @@ bool read_headers(Stream &strm, Headers &headers) { return true; } -bool read_content_with_length(Stream &strm, size_t len, - DownloadProgress progress, - ContentReceiverWithProgress out) { +enum class ReadContentResult { + Success, // Successfully read the content + PayloadTooLarge, // The content exceeds the specified payload limit + Error // An error occurred while reading the content +}; + +ReadContentResult read_content_with_length( + Stream &strm, size_t len, DownloadProgress progress, + ContentReceiverWithProgress out, + size_t payload_max_length = (std::numeric_limits::max)()) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; detail::BodyReader br; br.stream = &strm; + br.has_content_length = true; br.content_length = len; + br.payload_max_length = payload_max_length; br.chunked = false; br.bytes_read = 0; br.last_error = Error::Success; @@ -2287,36 +2727,27 @@ bool read_content_with_length(Stream &strm, size_t len, auto read_len = static_cast(len - r); auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); auto n = detail::read_body_content(&strm, br, buf, to_read); - if (n <= 0) { return false; } + if (n <= 0) { + // Check if it was a payload size error + if (br.last_error == Error::ExceedMaxPayloadSize) { + return ReadContentResult::PayloadTooLarge; + } + return ReadContentResult::Error; + } - if (!out(buf, static_cast(n), r, len)) { return false; } + if (!out(buf, static_cast(n), r, len)) { + return ReadContentResult::Error; + } r += 
static_cast(n); if (progress) { - if (!progress(r, len)) { return false; } + if (!progress(r, len)) { return ReadContentResult::Error; } } } - return true; + return ReadContentResult::Success; } -void skip_content_with_length(Stream &strm, size_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - size_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } -} - -enum class ReadContentResult { - Success, // Successfully read the content - PayloadTooLarge, // The content exceeds the specified payload limit - Error // An error occurred while reading the content -}; - ReadContentResult read_content_without_length(Stream &strm, size_t payload_max_length, ContentReceiverWithProgress out) { @@ -2462,12 +2893,13 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, if (is_invalid_value) { ret = false; - } else if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; } else if (len > 0) { - ret = read_content_with_length(strm, len, std::move(progress), out); + auto result = read_content_with_length( + strm, len, std::move(progress), out, payload_max_length); + ret = (result == ReadContentResult::Success); + if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + } } } @@ -3645,226 +4077,6 @@ bool has_crlf(const std::string &s) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -std::string message_digest(const std::string &s, const EVP_MD *algo) { - auto context = std::unique_ptr( - EVP_MD_CTX_new(), EVP_MD_CTX_free); - - unsigned int hash_length = 0; - unsigned char hash[EVP_MAX_MD_SIZE]; - - EVP_DigestInit_ex(context.get(), algo, nullptr); - EVP_DigestUpdate(context.get(), s.c_str(), s.size()); - EVP_DigestFinal_ex(context.get(), hash, &hash_length); - - std::stringstream ss; - for (auto i = 0u; i < hash_length; ++i) { - ss << std::hex << std::setw(2) << std::setfill('0') - << static_cast(hash[i]); - } - - return ss.str(); -} - -std::string MD5(const std::string &s) { - return message_digest(s, EVP_md5()); -} - -std::string SHA_256(const std::string &s) { - return message_digest(s, EVP_sha256()); -} - -std::string SHA_512(const std::string &s) { - return message_digest(s, EVP_sha512()); -} - -std::pair make_digest_authentication_header( - const Request &req, const std::map &auth, - size_t cnonce_count, const std::string &cnonce, const std::string &username, - const std::string &password, bool is_proxy = false) { - std::string nc; - { - std::stringstream ss; - ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; - nc = ss.str(); - } - - std::string qop; - if (auth.find("qop") != auth.end()) { - qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else if (qop.find("auth") != std::string::npos) { - qop = "auth"; - } else { - qop.clear(); - } - } - - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } - - std::string response; - { - auto H = algo == "SHA-256" ? detail::SHA_256 - : algo == "SHA-512" ? 
detail::SHA_512 - : detail::MD5; - - auto A1 = username + ":" + auth.at("realm") + ":" + password; - - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } - - if (qop.empty()) { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); - } else { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } - } - - auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; - - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - (qop.empty() ? ", response=\"" - : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + - cnonce + "\", response=\"") + - response + "\"" + - (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); - - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); -} - -bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { - detail::set_nonblocking(sock, true); - auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); - - char buf[1]; - return !SSL_peek(ssl, buf, 1) && - SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; -} - -#ifdef _WIN32 -// NOTE: This code came up with the following stackoverflow post: -// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store -bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } - - auto result = false; - PCCERT_CONTEXT pContext = NULL; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); - - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; - } - } - - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); - - return result; -} -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC -template -using CFObjectPtr = - std::unique_ptr::type, void (*)(CFTypeRef)>; - -void cf_object_ptr_deleter(CFTypeRef obj) { - if (obj) { CFRelease(obj); } -} - -bool retrieve_certs_from_keychain(CFObjectPtr &certs) { - CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; - CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, - kCFBooleanTrue}; - - CFObjectPtr query( - CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, - sizeof(keys) / sizeof(keys[0]), - &kCFTypeDictionaryKeyCallBacks, - &kCFTypeDictionaryValueCallBacks), - cf_object_ptr_deleter); - - if (!query) { return false; } - - CFTypeRef security_items = nullptr; - if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || - CFArrayGetTypeID() != CFGetTypeID(security_items)) { - return false; - } - - certs.reset(reinterpret_cast(security_items)); - return true; -} - -bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { - CFArrayRef root_security_items = nullptr; - if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { - return false; - } - - certs.reset(root_security_items); - return true; -} - -bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { - auto result = false; - for (auto i = 0; i < CFArrayGetCount(certs); ++i) { - const auto cert = reinterpret_cast( - CFArrayGetValueAtIndex(certs, i)); - - if (SecCertificateGetTypeID() != 
CFGetTypeID(cert)) { continue; } - - CFDataRef cert_data = nullptr; - if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != - errSecSuccess) { - continue; - } - - CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); - - auto encoded_cert = static_cast( - CFDataGetBytePtr(cert_data_ptr.get())); - - auto x509 = - d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); - - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; - } - } - - return result; -} - -bool load_system_certs_on_macos(X509_STORE *store) { - auto result = false; - CFObjectPtr certs(nullptr, cf_object_ptr_deleter); - if (retrieve_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store); - } - - if (retrieve_root_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store) || result; - } - - return result; -} -#endif // _WIN32 -#endif // CPPHTTPLIB_OPENSSL_SUPPORT - #ifdef _WIN32 class WSInit { public: @@ -3984,8 +4196,393 @@ bool is_field_content(const std::string &s) { bool is_field_value(const std::string &s) { return is_field_content(s); } } // namespace fields +} // namespace detail + +/* + * Group 2: detail namespace - SSL common utilities + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + tls::session_t session_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} +#elif defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} 
// namespace + +std::string MD5(const std::string &s) { + unsigned char hash[16]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_md5(reinterpret_cast(s.c_str()), s.size(), + hash); +#else + mbedtls_md5_ret(reinterpret_cast(s.c_str()), s.size(), + hash); +#endif + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[32]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha256(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha256_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[64]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha512(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha512_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} +#endif + +bool is_ip_address(const std::string &host) { + struct in_addr addr4; + struct in6_addr addr6; + return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || + inet_pton(AF_INET6, host.c_str(), &addr6) == 1; +} + +template +bool process_server_socket_ssl( + const std::atomic &svr_sock, tls::session_t session, + socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +bool process_client_socket_ssl( + tls::session_t session, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? 
auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +bool match_hostname(const std::string &pattern, + const std::string &hostname) { + // Exact match (case-insensitive) + if (detail::case_ignore::equal(hostname, pattern)) { return true; } + + // Split both pattern and hostname into components by '.' + std::vector pattern_components; + if (!pattern.empty()) { + split(pattern.data(), pattern.data() + pattern.size(), '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + } + + std::vector host_components; + if (!hostname.empty()) { + split(hostname.data(), hostname.data() + hostname.size(), '.', + [&](const char *b, const char *e) { + host_components.emplace_back(b, e); + }); + } + + // Component count must match + if (host_components.size() != pattern_components.size()) { return false; } + + // Compare each component with wildcard support + // Supports: "*" (full wildcard), "prefix*" (partial wildcard) + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + auto itr = pattern_components.begin(); + for (const auto &h : host_components) { + auto &p = *itr; + if (!detail::case_ignore::equal(p, h) && p != "*") { + bool partial_match = false; + if (!p.empty() && p[p.size() - 1] == '*') { + const auto prefix_length = p.size() - 1; + if (prefix_length == 0) { + partial_match = true; + } else if (h.size() >= prefix_length) { + partial_match = + std::equal(p.begin(), + p.begin() + static_cast( + prefix_length), + h.begin(), [](const char ca, const char cb) { + return detail::case_ignore::to_lower(ca) == + detail::case_ignore::to_lower(cb); + }); + } + } + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} + +#ifdef _WIN32 +// Verify certificate using Windows CertGetCertificateChain API. +// This provides real-time certificate validation with Windows Update +// integration, independent of the TLS backend (OpenSSL or MbedTLS). 
+bool verify_cert_with_windows_schannel( + const std::vector &der_cert, const std::string &hostname, + bool verify_hostname, unsigned long &out_error) { + if (der_cert.empty()) { return false; } + + out_error = 0; + + // Create Windows certificate context from DER data + auto cert_context = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, der_cert.data(), + static_cast(der_cert.size())); + + if (!cert_context) { + out_error = GetLastError(); + return false; + } + + auto cert_guard = + scope_exit([&] { CertFreeCertificateContext(cert_context); }); + + // Setup chain parameters + CERT_CHAIN_PARA chain_para = {}; + chain_para.cbSize = sizeof(chain_para); + + // Build certificate chain with revocation checking + PCCERT_CHAIN_CONTEXT chain_context = nullptr; + auto chain_result = CertGetCertificateChain( + nullptr, cert_context, nullptr, cert_context->hCertStore, &chain_para, + CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_END_CERT | + CERT_CHAIN_REVOCATION_ACCUMULATIVE_TIMEOUT, + nullptr, &chain_context); + + if (!chain_result || !chain_context) { + out_error = GetLastError(); + return false; + } + + auto chain_guard = + scope_exit([&] { CertFreeCertificateChain(chain_context); }); + + // Check if chain has errors + if (chain_context->TrustStatus.dwErrorStatus != CERT_TRUST_NO_ERROR) { + out_error = chain_context->TrustStatus.dwErrorStatus; + return false; + } + + // Verify SSL policy + SSL_EXTRA_CERT_CHAIN_POLICY_PARA extra_policy_para = {}; + extra_policy_para.cbSize = sizeof(extra_policy_para); +#ifdef AUTHTYPE_SERVER + extra_policy_para.dwAuthType = AUTHTYPE_SERVER; +#endif + + std::wstring whost; + if (verify_hostname) { + whost = u8string_to_wstring(hostname.c_str()); + extra_policy_para.pwszServerName = const_cast(whost.c_str()); + } + + CERT_CHAIN_POLICY_PARA policy_para = {}; + policy_para.cbSize = sizeof(policy_para); +#ifdef CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS + policy_para.dwFlags = CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS; +#else + policy_para.dwFlags = 0; +#endif + policy_para.pvExtraPolicyPara = &extra_policy_para; + + CERT_CHAIN_POLICY_STATUS policy_status = {}; + policy_status.cbSize = sizeof(policy_status); + + if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chain_context, + &policy_para, &policy_status)) { + out_error = GetLastError(); + return false; + } + + if (policy_status.dwError != 0) { + out_error = policy_status.dwError; + return false; + } + + return true; +} +#endif // _WIN32 } // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 3: httplib namespace - Non-SSL public API implementations + */ + +void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} const char *status_message(int status) { switch (status) { @@ -4426,6 +5023,11 @@ make_bearer_token_authentication_header(const std::string &token, } // Request implementation +size_t Request::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Request::has_header(const std::string &key) const { return detail::has_header(headers, key); } @@ -4547,6 +5149,11 @@ size_t 
MultipartFormData::get_file_count(const std::string &key) const { } // Response implementation +size_t Response::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Response::has_header(const std::string &key) const { return headers.find(key) != headers.end(); } @@ -4662,6 +5269,12 @@ void Response::set_file_content(const std::string &path) { } // Result implementation +size_t Result::get_request_header_value_u64(const std::string &key, + size_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); } @@ -4697,13 +5310,16 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { if (!chunked) { // Content-Length based reading - if (bytes_read >= content_length) { + if (has_content_length && bytes_read >= content_length) { eof = true; return 0; } - auto remaining = content_length - bytes_read; - auto to_read = (std::min)(len, remaining); + auto to_read = len; + if (has_content_length) { + auto remaining = content_length - bytes_read; + to_read = (std::min)(len, remaining); + } auto n = stream->read(buf, to_read); if (n < 0) { @@ -4721,7 +5337,12 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } bytes_read += static_cast(n); - if (bytes_read >= content_length) { eof = true; } + if (has_content_length && bytes_read >= content_length) { eof = true; } + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } return n; } @@ -4745,9 +5366,83 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } bytes_read += static_cast(n); + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } return n; } +// ThreadPool implementation +ThreadPool::ThreadPool(size_t n, size_t mqr) + : shutdown_(false), max_queued_requests_(mqr) { + threads_.reserve(n); + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } +} + +bool ThreadPool::enqueue(std::function fn) { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; +} + +void ThreadPool::shutdown() { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... 
+ for (auto &t : threads_) { + t.join(); + } +} + +ThreadPool::worker::worker(ThreadPool &pool) : pool_(pool) {} + +void ThreadPool::worker::operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, + [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif +} + +/* + * Group 1 (continued): detail namespace - Stream implementations + */ + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -5076,6 +5771,155 @@ bool check_and_write_headers(Stream &strm, Headers &headers, } // namespace detail +/* + * Group 2 (continued): detail namespace - SSLSocketStream implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +// SSL socket stream implementation +SSLSocketStream::SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), session_(session), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Clear AUTO_RETRY for proper non-blocking I/O timeout handling + // Note: create_session() also clears this, but SSLClient currently + // uses ssl_new() which does not. Until full TLS API migration is complete, + // we need to ensure AUTO_RETRY is cleared here regardless of how the + // SSL session was created. 
+ SSL_clear_mode(static_cast(session), SSL_MODE_AUTO_RETRY); +#endif +} + +SSLSocketStream::~SSLSocketStream() = default; + +bool SSLSocketStream::is_readable() const { + return tls::pending(session_) > 0; +} + +bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_) && !tls::is_peer_closed(session_, sock_); +} + +ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (tls::pending(session_) > 0) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else if (wait_readable()) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantRead || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantRead) { +#endif + if (tls::pending(session_) > 0) { + return tls::read(session_, ptr, size, err); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::read(session_, ptr, size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } else if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else { + error_ = Error::Timeout; + return -1; + } +} + +ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = + std::min(size, (std::numeric_limits::max)()); + + tls::TlsError err; + auto ret = tls::write(session_, ptr, handle_size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantWrite || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantWrite) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::write(session_, ptr, handle_size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } + return -1; +} + +void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +socket_t SSLSocketStream::socket() const { return sock_; } + +time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +} // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 4: Server implementation + */ + // HTTP server implementation Server::Server() : new_task_queue( @@ -5677,36 +6521,40 @@ bool Server::read_content_core( // are true (no Transfer-Encoding and no Content-Length), then the message // body length is zero 
(no message body is present). // - // For non-SSL builds, peek into the socket to detect clients that send a - // body without a Content-Length header (raw HTTP over TCP). If there is - // pending data that exceeds the configured payload limit, treat this as an - // oversized request and fail early (causing connection close). For SSL - // builds we cannot reliably peek the decrypted application bytes, so keep - // the original behaviour. -#if !defined(CPPHTTPLIB_OPENSSL_SUPPORT) + // For non-SSL builds, detect clients that send a body without a + // Content-Length header (raw HTTP over TCP). Check both the stream's + // internal read buffer (data already read from the socket during header + // parsing) and the socket itself for pending data. If data is found and + // exceeds the configured payload limit, reject with 413. + // For SSL builds we cannot reliably peek the decrypted application bytes, + // so keep the original behaviour. +#if !defined(CPPHTTPLIB_SSL_ENABLED) if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { - // Only peek if payload_max_length is set to a finite value + // Only check if payload_max_length is set to a finite value if (payload_max_length_ > 0 && payload_max_length_ < (std::numeric_limits::max)()) { - socket_t s = strm.socket(); - if (s != INVALID_SOCKET) { - // Peek to check if there is any pending data - char peekbuf[1]; - ssize_t n = ::recv(s, peekbuf, 1, MSG_PEEK); - if (n > 0) { - // There is data, so read it with payload limit enforcement - auto result = detail::read_content_without_length( - strm, payload_max_length_, out); - if (result == detail::ReadContentResult::PayloadTooLarge) { - res.status = StatusCode::PayloadTooLarge_413; - return false; - } else if (result != detail::ReadContentResult::Success) { - return false; - } - return true; + // Check if there is data already buffered in the stream (read during + // header parsing) or pending on the socket. Use a non-blocking socket + // check to avoid deadlock when the client sends no body. + bool has_data = strm.is_readable(); + if (!has_data) { + socket_t s = strm.socket(); + if (s != INVALID_SOCKET) { + has_data = detail::select_read(s, 0, 0) > 0; } } + if (has_data) { + auto result = + detail::read_content_without_length(strm, payload_max_length_, out); + if (result == detail::ReadContentResult::PayloadTooLarge) { + res.status = StatusCode::PayloadTooLarge_413; + return false; + } else if (result != detail::ReadContentResult::Success) { + return false; + } + return true; + } } return true; } @@ -5815,8 +6663,10 @@ bool Server::check_if_not_modified(const Request &req, Response &res, // simplified implementation requires exact matches. 
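// Illustrative sketch, not part of the diff above: the corrected If-None-Match
// comparison used just below. A header segment matches only when it is exactly
// "*" or exactly equal to the ETag, not merely a shared prefix.
#include <algorithm>
#include <cstring>
#include <string>

static bool segment_matches(const char *b, const char *e, const std::string &etag) {
  const auto seg_len = static_cast<size_t>(e - b);
  return (seg_len == 1 && *b == '*') ||
         (seg_len == etag.size() && std::equal(b, e, etag.begin()));
}

int main() {
  const std::string etag = "\"abc\"";
  const char header_segment[] = "\"abc\"";
  return segment_matches(header_segment,
                         header_segment + std::strlen(header_segment), etag)
             ? 0
             : 1; // exit code 0 on match
}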
auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', [&](const char *b, const char *e) { - return std::equal(b, e, "*") || - std::equal(b, e, etag.begin()); + auto seg_len = static_cast(e - b); + return (seg_len == 1 && *b == '*') || + (seg_len == etag.size() && + std::equal(b, e, etag.begin())); }); if (ret) { @@ -6518,6 +7368,9 @@ void Server::output_error_log(const Error &err, } } +/* + * Group 5: ClientImpl and Client (Universal) implementation + */ // HTTP client implementation ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) {} @@ -6561,10 +7414,6 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; bearer_token_auth_token_ = rhs.bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; -#endif keep_alive_ = rhs.keep_alive_; follow_location_ = rhs.follow_location_; path_encode_ = rhs.path_encode_; @@ -6574,28 +7423,27 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { socket_options_ = rhs.socket_options_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; + payload_max_length_ = rhs.payload_max_length_; + has_payload_max_length_ = rhs.has_payload_max_length_; interface_ = rhs.interface_; proxy_host_ = rhs.proxy_host_; proxy_port_ = rhs.proxy_port_; proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ca_cert_file_path_ = rhs.ca_cert_file_path_; - ca_cert_dir_path_ = rhs.ca_cert_dir_path_; - ca_cert_store_ = rhs.ca_cert_store_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - server_certificate_verification_ = rhs.server_certificate_verification_; - server_hostname_verification_ = rhs.server_hostname_verification_; - server_certificate_verifier_ = rhs.server_certificate_verifier_; -#endif logger_ = rhs.logger_; error_logger_ = rhs.error_logger_; + +#ifdef CPPHTTPLIB_SSL_ENABLED + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; +#endif } socket_t ClientImpl::create_client_socket(Error &error) const { @@ -6631,22 +7479,6 @@ bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { return create_and_connect_socket(socket, error); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { - if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } - - if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } - - if (!initialize_ssl(socket, error)) { - shutdown_socket(socket); - close_socket(socket); - return false; - } - - return true; -} -#endif - void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any 
requests in flight from threads other than us, then it's @@ -6671,9 +7503,10 @@ void ClientImpl::close_socket(Socket &socket) { socket_requests_are_from_thread_ == std::this_thread::get_id()); // It is also a bug if this happens while SSL is still active -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED assert(socket.ssl == nullptr); #endif + if (socket.sock == INVALID_SOCKET) { return; } detail::close_socket(socket.sock); socket.sock = INVALID_SOCKET; @@ -6722,6 +7555,8 @@ bool ClientImpl::send(Request &req, Response &res, Error &error) { if (error == Error::SSLPeerCouldBeClosed_) { assert(!ret); ret = send_(req, res, error); + // If still failing with SSLPeerCouldBeClosed_, convert to Read error + if (error == Error::SSLPeerCouldBeClosed_) { error = Error::Read; } } return ret; } @@ -6739,9 +7574,9 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { is_alive = false; } } @@ -6765,7 +7600,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED // TODO: refactoring if (is_ssl()) { auto &scli = static_cast(*this); @@ -6847,9 +7682,9 @@ Result ClientImpl::send_(Request &&req) { auto res = detail::make_unique(); auto error = Error::Success; auto ret = send(req, *res, error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers), - last_ssl_error_, last_openssl_error_}; + last_ssl_error_, last_backend_error_}; #else return Result{ret ? 
std::move(res) : nullptr, error, std::move(req.headers)}; #endif @@ -6926,9 +7761,9 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, auto is_alive = false; if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { is_alive = false; } } @@ -6946,7 +7781,7 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, return handle; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_ssl()) { auto &scli = static_cast(*this); if (!proxy_host_.empty() && proxy_port_ != -1) { @@ -6962,11 +7797,12 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, transfer_socket_ownership_to_handle(handle); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && handle.connection_->ssl) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && handle.connection_->session) { handle.socket_stream_ = detail::make_unique( - handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + handle.connection_->sock, handle.connection_->session, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_); } else { handle.socket_stream_ = detail::make_unique( handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, @@ -7016,9 +7852,11 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, } handle.body_reader_.stream = handle.stream_; + handle.body_reader_.payload_max_length = payload_max_length_; auto content_length_str = handle.response->get_header_value("Content-Length"); if (!content_length_str.empty()) { + handle.body_reader_.has_content_length = true; handle.body_reader_.content_length = static_cast(std::stoull(content_length_str)); } @@ -7066,6 +7904,7 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, auto to_copy = (std::min)(len, available); std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); decompress_offset_ += to_copy; + decompressed_bytes_read_ += to_copy; return static_cast(to_copy); } @@ -7081,12 +7920,16 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, if (n <= 0) { return n; } - bool decompress_ok = - decompressor_->decompress(compressed_buf, static_cast(n), - [this](const char *data, size_t data_len) { - decompress_buffer_.append(data, data_len); - return true; - }); + bool decompress_ok = decompressor_->decompress( + compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + auto limit = body_reader_.payload_max_length; + if (decompressed_bytes_read_ + decompress_buffer_.size() > limit) { + return false; + } + return true; + }); if (!decompress_ok) { body_reader_.last_error = Error::Read; @@ -7099,6 +7942,7 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, auto to_copy = (std::min)(len, decompress_buffer_.size()); std::memcpy(buf, decompress_buffer_.data(), to_copy); decompress_offset_ = to_copy; + decompressed_bytes_read_ += to_copy; return static_cast(to_copy); } @@ -7121,7 +7965,6 @@ void ClientImpl::StreamHandle::parse_trailers_if_needed() { } } -// Inline method implementations for `ChunkedDecoder`. 
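// Illustrative usage sketch, not part of the diff above: the client-side
// payload cap wired through body_reader_.payload_max_length above. Assumptions:
// the universal httplib::Client forwards set_payload_max_length() to ClientImpl
// like its other setters; URL, path, and limit are example values only.
#include "httplib.h" // assumed include path for this patched source tree

int main() {
  httplib::Client cli("http://localhost:8080");
  cli.set_payload_max_length(8 * 1024 * 1024); // cap response bodies at 8 MiB
  auto res = cli.Get("/large");
  if (!res) {
    // An oversized (or over-decompressed) body surfaces as a read error
    // instead of an unbounded res.body.
    return 1;
  }
  return 0;
}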
namespace detail { ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} @@ -7185,8 +8028,8 @@ bool ChunkedDecoder::parse_trailers_into(Headers &dest, void ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { handle.connection_->sock = socket_.sock; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - handle.connection_->ssl = socket_.ssl; +#ifdef CPPHTTPLIB_SSL_ENABLED + handle.connection_->session = socket_.ssl; socket_.ssl = nullptr; #endif socket_.sock = INVALID_SOCKET; @@ -7239,7 +8082,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, ret = redirect(req, res, error); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if ((res.status == StatusCode::Unauthorized_401 || res.status == StatusCode::ProxyAuthenticationRequired_407) && req.authorization_count_ < 5) { @@ -7343,7 +8186,7 @@ bool ClientImpl::create_redirect_client( // Create appropriate client type and handle redirect if (need_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED // Create SSL client for HTTPS redirect SSLClient redirect_client(host, port); @@ -7363,9 +8206,10 @@ bool ClientImpl::create_redirect_client( server_hostname_verification_); } - // Handle CA certificate store and paths if available - if (ca_cert_store_ && X509_STORE_up_ref(ca_cert_store_)) { - redirect_client.set_ca_cert_store(ca_cert_store_); + // Transfer CA certificate to redirect client + if (!ca_cert_pem_.empty()) { + redirect_client.load_ca_cert_store(ca_cert_pem_.c_str(), + ca_cert_pem_.size()); } if (!ca_cert_file_path_.empty()) { redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); @@ -7418,7 +8262,7 @@ void ClientImpl::setup_redirect_client(ClientType &client) { if (!bearer_token_auth_token_.empty()) { client.set_bearer_token_auth(bearer_token_auth_token_); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!digest_auth_username_.empty()) { client.set_digest_auth(digest_auth_username_, digest_auth_password_); } @@ -7438,7 +8282,7 @@ void ClientImpl::setup_redirect_client(ClientType &client) { if (!proxy_bearer_token_auth_token_.empty()) { client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!proxy_digest_auth_username_.empty()) { client.set_proxy_digest_auth(proxy_digest_auth_username_, proxy_digest_auth_password_); @@ -7809,9 +8653,9 @@ Result ClientImpl::send_with_content_provider_and_receiver( std::move(content_provider_without_length), content_type, std::move(content_receiver), error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, - last_openssl_error_}; + last_backend_error_}; #else return Result{std::move(res), error, std::move(req.headers)}; #endif @@ -7851,11 +8695,11 @@ bool ClientImpl::process_request(Stream &strm, Request &req, auto write_request_success = write_request(strm, req, close_connection, error, expect_100_continue); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl()) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && !expect_100_continue) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; output_error_log(error, &req); return false; @@ -7937,6 +8781,11 @@ bool ClientImpl::process_request(Stream &strm, Request &req, [&](const char *buf, 
size_t n, size_t /*off*/, size_t /*len*/) { assert(res.body.size() + n <= res.body.max_size()); + if (payload_max_length_ > 0 && + (res.body.size() >= payload_max_length_ || + n > payload_max_length_ - res.body.size())) { + return false; + } res.body.append(buf, n); return true; }); @@ -7965,9 +8814,12 @@ bool ClientImpl::process_request(Stream &strm, Request &req, if (res.status != StatusCode::NotModified_304) { int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, std::move(progress), - std::move(out), decompress_)) { + auto max_length = (!has_payload_max_length_ && req.content_receiver) + ? (std::numeric_limits::max)() + : payload_max_length_; + if (!detail::read_content(strm, res, max_length, dummy_status, + std::move(progress), std::move(out), + decompress_)) { if (error != Error::Canceled) { error = Error::Read; } output_error_log(error, &req); return false; @@ -8878,14 +9730,6 @@ void ClientImpl::set_bearer_token_auth(const std::string &token) { bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_digest_auth(const std::string &username, - const std::string &password) { - digest_auth_username_ = username; - digest_auth_password_ = password; -} -#endif - void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } @@ -8922,6 +9766,11 @@ void ClientImpl::set_compress(bool on) { compress_ = on; } void ClientImpl::set_decompress(bool on) { decompress_ = on; } +void ClientImpl::set_payload_max_length(size_t length) { + payload_max_length_ = length; + has_payload_max_length_ = true; +} + void ClientImpl::set_interface(const std::string &intf) { interface_ = intf; } @@ -8941,11 +9790,11 @@ void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { proxy_bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; +#ifdef CPPHTTPLIB_SSL_ENABLED +void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; } void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, @@ -8954,12 +9803,23 @@ void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, ca_cert_dir_path_ = ca_cert_dir_path; } -void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store && ca_cert_store != ca_cert_store_) { - ca_cert_store_ = ca_cert_store; - } +void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } +void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} +#endif + +// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); @@ -8984,17 +9844,9 @@ X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, return cts; } -void ClientImpl::enable_server_certificate_verification(bool 
enabled) { - server_certificate_verification_ = enabled; -} - -void ClientImpl::enable_server_hostname_verification(bool enabled) { - server_hostname_verification_ = enabled; -} - void ClientImpl::set_server_certificate_verifier( - std::function verifier) { - server_certificate_verifier_ = verifier; + std::function /*verifier*/) { + // Base implementation does nothing - SSLClient overrides this } #endif @@ -9007,958 +9859,24 @@ void ClientImpl::set_error_logger(ErrorLogger error_logger) { } /* - * SSL Implementation + * SSL/TLS Common Implementation */ -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -namespace detail { -bool is_ip_address(const std::string &host) { - struct in_addr addr4; - struct in6_addr addr6; - return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || - inet_pton(AF_INET6, host.c_str(), &addr6) == 1; -} - -template -SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } - - if (ssl) { - set_nonblocking(sock, true); - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - BIO_set_nbio(bio, 1); - SSL_set_bio(ssl, bio, bio); - - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - set_nonblocking(sock, false); - return nullptr; - } - BIO_set_nbio(bio, 0); - set_nonblocking(sock, false); - } - - return ssl; -} - -void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, - bool shutdown_gracefully) { - // sometimes we may want to skip this to try to avoid SIGPIPE if we know - // the remote has closed the network connection - // Note that it is not always possible to avoid SIGPIPE, this is merely a - // best-efforts. - if (shutdown_gracefully) { - (void)(sock); - // SSL_shutdown() returns 0 on first call (indicating close_notify alert - // sent) and 1 on subsequent call (indicating close_notify alert received) - if (SSL_shutdown(ssl) == 0) { - // Expected to return 1, but even if it doesn't, we free ssl - SSL_shutdown(ssl); - } - } - - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); -} - -template -bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, - U ssl_connect_or_accept, - time_t timeout_sec, time_t timeout_usec, - int *ssl_error) { - auto res = 0; - while ((res = ssl_connect_or_accept(ssl)) != 1) { - auto err = SSL_get_error(ssl, res); - switch (err) { - case SSL_ERROR_WANT_READ: - if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - case SSL_ERROR_WANT_WRITE: - if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - default: break; - } - if (ssl_error) { *ssl_error = err; } - return false; - } - return true; -} - -template -bool process_server_socket_ssl( - const std::atomic &svr_sock, SSL *ssl, socket_t sock, - size_t keep_alive_max_count, time_t keep_alive_timeout_sec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { - return process_server_socket_core( - svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); -} - -template -bool process_client_socket_ssl( - SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t 
write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, max_timeout_msec, - start_time); - return callback(strm); -} - -// SSL socket stream implementation -SSLSocketStream::SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time) - : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), - read_timeout_usec_(read_timeout_usec), - write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time_(start_time) { - SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); -} - -SSLSocketStream::~SSLSocketStream() = default; - -bool SSLSocketStream::is_readable() const { - return SSL_pending(ssl_) > 0; -} - -bool SSLSocketStream::wait_readable() const { - if (max_timeout_msec_ <= 0) { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; - } - - time_t read_timeout_sec; - time_t read_timeout_usec; - calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, - read_timeout_usec_, read_timeout_sec, read_timeout_usec); - - return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; -} - -bool SSLSocketStream::wait_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); -} - -ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret == 0) { error_ = Error::ConnectionClosed; } - return ret; - } else if (wait_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { -#endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (wait_readable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } else if (ret == 0) { - error_ = Error::ConnectionClosed; - } - return ret; - } else { - error_ = Error::Timeout; - return -1; - } -} - -ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (wait_writable()) { - auto handle_size = static_cast( - std::min(size, (std::numeric_limits::max)())); - - auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { -#endif - if (wait_writable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } - return ret; - } - return -1; -} - -void SSLSocketStream::get_remote_ip_and_port(std::string &ip, - int &port) const { - 
detail::get_remote_ip_and_port(sock_, ip, port); -} - -void SSLSocketStream::get_local_ip_and_port(std::string &ip, - int &port) const { - detail::get_local_ip_and_port(sock_, ip, port); -} - -socket_t SSLSocketStream::socket() const { return sock_; } - -time_t SSLSocketStream::duration() const { - return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time_) - .count(); -} - -} // namespace detail - -// SSL HTTP server implementation -SSLServer::SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path, - const char *client_ca_cert_dir_path, - const char *private_key_password) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (private_key_password != nullptr && (private_key_password[0] != '\0')) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, - reinterpret_cast(const_cast(private_key_password))); - } - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1 || - SSL_CTX_check_private_key(ctx_) != 1) { - last_ssl_error_ = static_cast(ERR_get_error()); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - // Set client CA list to be sent to clients during TLS handshake - if (client_ca_cert_file_path) { - auto ca_list = SSL_load_client_CA_file(client_ca_cert_file_path); - if (ca_list != nullptr) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to load client CA list, but we continue since - // SSL_CTX_load_verify_locations already succeeded and - // certificate verification will still work - last_ssl_error_ = static_cast(ERR_get_error()); - } - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - // Extract CA names from the store and set them as the client CA list - auto ca_list = extract_ca_names_from_x509_store(client_ca_cert_store); - if (ca_list) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to extract CA names, record the error - last_ssl_error_ = static_cast(ERR_get_error()); - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer( - const std::function &setup_ssl_ctx_callback) { - ctx_ = SSL_CTX_new(TLS_method()); - if (ctx_) { - if (!setup_ssl_ctx_callback(*ctx_)) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } -} - -bool SSLServer::is_valid() const { return ctx_; } - -SSL_CTX *SSLServer::ssl_context() const { return ctx_; } - -void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE 
*client_ca_cert_store) { - - std::lock_guard guard(ctx_mutex_); - - SSL_CTX_use_certificate(ctx_, cert); - SSL_CTX_use_PrivateKey(ctx_, private_key); - - if (client_ca_cert_store != nullptr) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - } -} - -bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new( - sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - return detail::ssl_connect_or_accept_nonblocking( - sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_, - &last_ssl_error_); - }, - [](SSL * /*ssl2*/) { return true; }); - - auto ret = false; - if (ssl) { - std::string remote_addr; - int remote_port = 0; - detail::get_remote_ip_and_port(sock, remote_addr, remote_port); - - std::string local_addr; - int local_port = 0; - detail::get_local_ip_and_port(sock, local_addr, local_port); - - ret = detail::process_server_socket_ssl( - svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [&](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, remote_addr, remote_port, local_addr, - local_port, close_connection, - connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - // Shutdown gracefully if the result seemed successful, non-gracefully if - // the connection appeared to be closed. - const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); - } - - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; -} - -STACK_OF(X509_NAME) * SSLServer::extract_ca_names_from_x509_store( - X509_STORE *store) { - if (!store) { return nullptr; } - - auto ca_list = sk_X509_NAME_new_null(); - if (!ca_list) { return nullptr; } - - // Get all objects from the store - auto objs = X509_STORE_get0_objects(store); - if (!objs) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - // Iterate through objects and extract certificate subject names - for (int i = 0; i < sk_X509_OBJECT_num(objs); i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { - auto subject = X509_get_subject_name(cert); - if (subject) { - auto name_dup = X509_NAME_dup(subject); - if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } - } - } - } - } - - // If no names were extracted, free the list and return nullptr - if (sk_X509_NAME_num(ca_list) == 0) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - return ca_list; -} - -// SSL HTTP client implementation -SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path, - const std::string &private_key_password) - : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - 
if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key, - const std::string &private_key_password) - : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (client_cert != nullptr && client_key != nullptr) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } - // Make sure to shut down SSL since shutdown_ssl will resolve to the - // base function rather than the derived function once we get to the - // base class destructor, and won't free the SSL (causing a leak). - shutdown_ssl_impl(socket_, true); -} - -bool SSLClient::is_valid() const { return ctx_; } - -void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { - if (ctx_) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { - // Free memory allocated for old cert and use new store - // `ca_cert_store` - SSL_CTX_set_cert_store(ctx_, ca_cert_store); - ca_cert_store_ = ca_cert_store; - } - } else { - X509_STORE_free(ca_cert_store); - } - } -} - -void SSLClient::load_ca_cert_store(const char *ca_cert, - std::size_t size) { - set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); -} - -long SSLClient::get_openssl_verify_result() const { - return verify_result_; -} - -SSL_CTX *SSLClient::ssl_context() const { return ctx_; } - -bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { - if (!is_valid()) { - error = Error::SSLConnection; - return false; - } - return ClientImpl::create_and_connect_socket(socket, error); -} - -// Assumes that socket_mutex_ is locked and that there are no requests in -// flight -bool SSLClient::connect_with_proxy( - Socket &socket, - std::chrono::time_point start_time, - Response &res, bool &success, Error &error) { - success = true; - Response proxy_res; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = - detail::make_host_and_port_string_always_port(host_, port_); - if (max_timeout_msec_ > 0) { - req2.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req2, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are no - // requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - - if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(proxy_res, auth, true)) { - // Close the current socket and create a new 
one for the authenticated - // request - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - - // Create a new socket for the authenticated CONNECT request - if (!ensure_socket_connection(socket, error)) { - success = false; - output_error_log(error, nullptr); - return false; - } - - proxy_res = Response(); - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = detail::make_host_and_port_string_always_port( - host_, port_); - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - if (max_timeout_msec_ > 0) { - req3.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req3, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - } - } - } - - // If status code is not 200, proxy request is failed. - // Set error to ProxyConnection and return proxy response - // as the response of the request - if (proxy_res.status != StatusCode::OK_200) { - error = Error::ProxyConnection; - output_error_log(error, nullptr); - res = std::move(proxy_res); - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - return false; - } - - return true; -} - -bool SSLClient::load_certs() { - auto ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else if (!ca_cert_store_) { - auto loaded = false; -#ifdef _WIN32 - loaded = - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC - loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); -#endif // _WIN32 - if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } - } - }); - - return ret; -} - -bool SSLClient::initialize_ssl(Socket &socket, Error &error) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - if (server_certificate_verification_) { - if (!load_certs()) { - error = Error::SSLLoadingCerts; - output_error_log(error, nullptr); - return false; - } - SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); - } - - if (!detail::ssl_connect_or_accept_nonblocking( - socket.sock, ssl2, SSL_connect, connection_timeout_sec_, - connection_timeout_usec_, &last_ssl_error_)) { - error = Error::SSLConnection; - output_error_log(error, nullptr); - return false; - } - - if (server_certificate_verification_) { - auto verification_status = SSLVerifierResponse::NoDecisionMade; - - if (server_certificate_verifier_) { - verification_status = server_certificate_verifier_(ssl2); - } - - if (verification_status == SSLVerifierResponse::CertificateRejected) { - 
last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (verification_status == SSLVerifierResponse::NoDecisionMade) { - verify_result_ = SSL_get_verify_result(ssl2); - - if (verify_result_ != X509_V_OK) { - last_openssl_error_ = static_cast(verify_result_); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - auto se = detail::scope_exit([&] { X509_free(server_cert); }); - - if (server_cert == nullptr) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (server_hostname_verification_) { - if (!verify_host(server_cert)) { - last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; - error = Error::SSLServerHostnameVerification; - output_error_log(error, nullptr); - return false; - } - } - } - } - - return true; - }, - [&](SSL *ssl2) { - // Set SNI only if host is not IP address - if (!detail::is_ip_address(host_)) { -#if defined(OPENSSL_IS_BORINGSSL) - SSL_set_tlsext_host_name(ssl2, host_.c_str()); -#else - // NOTE: Direct call instead of using the OpenSSL macro to suppress - // -Wold-style-cast warning - SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, - TLSEXT_NAMETYPE_host_name, - static_cast(const_cast(host_.c_str()))); -#endif - } - return true; - }); - - if (ssl) { - socket.ssl = ssl; - return true; - } - - if (ctx_ == nullptr) { - error = Error::SSLConnection; - last_openssl_error_ = ERR_get_error(); - } - - shutdown_socket(socket); - close_socket(socket); - return false; -} - -void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { - shutdown_ssl_impl(socket, shutdown_gracefully); -} - -void SSLClient::shutdown_ssl_impl(Socket &socket, - bool shutdown_gracefully) { - if (socket.sock == INVALID_SOCKET) { - assert(socket.ssl == nullptr); - return; - } - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, - shutdown_gracefully); - socket.ssl = nullptr; - } - assert(socket.ssl == nullptr); -} - -bool SSLClient::process_socket( - const Socket &socket, - std::chrono::time_point start_time, - std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); -} - -bool SSLClient::is_ssl() const { return true; } - -bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" - - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. - - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. 
- - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. - - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); -} - -bool -SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; - - auto type = GEN_DNS; - - struct in6_addr addr6 = {}; - struct in_addr addr = {}; - size_t addr_len = 0; - -#ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); +ClientConnection::~ClientConnection() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (session) { + tls::shutdown(session, true); + tls::free_session(session); + session = nullptr; } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_matched = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (!val || val->type != type) { continue; } - - auto name = - reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); - if (name == nullptr) { continue; } - - auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); - - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { - ip_matched = true; - } - break; - } - } - - if (dsn_matched || ip_matched) { ret = true; } + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; } - - GENERAL_NAMES_free(const_cast( - reinterpret_cast(alt_names))); - return ret; } -bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); - - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); - - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); - } - } - - return false; -} - -bool SSLClient::check_host_name(const char *pattern, - size_t pattern_len) const { - // Exact match (case-insensitive) - if (host_.size() == pattern_len && - detail::case_ignore::equal(host_, std::string(pattern, pattern_len))) { - return true; - } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(b, e); - }); - - if (host_components_.size() != pattern_components.size()) { return false; } - - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (!httplib::detail::case_ignore::equal(p, h) && p != "*") { - bool partial_match = false; - if (!p.empty() && p[p.size() - 1] == '*') { - const auto prefix_length = p.size() - 1; - if (prefix_length == 0) { - partial_match = true; - } else if (h.size() >= prefix_length) { - partial_match = - std::equal(p.begin(), - p.begin() + static_cast( - prefix_length), - h.begin(), [](const char ca, const char cb) { - return httplib::detail::case_ignore::to_lower(ca) == - 
httplib::detail::case_ignore::to_lower(cb); - }); - } - } - if (!partial_match) { return false; } - } - ++itr; - } - - return true; -} -#endif - // Universal client implementation Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) {} @@ -9973,7 +9891,7 @@ Client::Client(const std::string &scheme_host_port, if (std::regex_match(scheme_host_port, m, re)) { auto scheme = m[1].str(); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else if (!scheme.empty() && scheme != "http") { @@ -9994,7 +9912,7 @@ Client::Client(const std::string &scheme_host_port, auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); if (is_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); is_ssl_ = is_ssl; @@ -10579,12 +10497,6 @@ void Client::set_basic_auth(const std::string &username, void Client::set_bearer_token_auth(const std::string &token) { cli_->set_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_digest_auth(username, password); -} -#endif void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } void Client::set_follow_location(bool on) { @@ -10602,6 +10514,10 @@ void Client::set_compress(bool on) { cli_->set_compress(on); } void Client::set_decompress(bool on) { cli_->set_decompress(on); } +void Client::set_payload_max_length(size_t length) { + cli_->set_payload_max_length(length); +} + void Client::set_interface(const std::string &intf) { cli_->set_interface(intf); } @@ -10616,27 +10532,6 @@ void Client::set_proxy_basic_auth(const std::string &username, void Client::set_proxy_bearer_token_auth(const std::string &token) { cli_->set_proxy_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_proxy_digest_auth(username, password); -} -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::enable_server_certificate_verification(bool enabled) { - cli_->enable_server_certificate_verification(enabled); -} - -void Client::enable_server_hostname_verification(bool enabled) { - cli_->enable_server_hostname_verification(enabled); -} - -void Client::set_server_certificate_verifier( - std::function verifier) { - cli_->set_server_certificate_verifier(verifier); -} -#endif void Client::set_logger(Logger logger) { cli_->set_logger(std::move(logger)); @@ -10646,35 +10541,3399 @@ void Client::set_error_logger(ErrorLogger error_logger) { cli_->set_error_logger(std::move(error_logger)); } +/* + * Group 6: SSL Server and Client implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +// SSL HTTP server implementation +SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + using namespace tls; + + ctx_ = create_server_context(); + if (!ctx_) { return; } + + // Load server certificate and private key + if (!set_server_cert_file(ctx_, cert_path, private_key_path, + private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + + // Load client CA certificates for client authentication + if (client_ca_cert_file_path || 
client_ca_cert_dir_path) { + if (!set_client_ca_file(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + // Enable client certificate verification + set_verify_client(ctx_, true); + } +} + +SSLServer::SSLServer(const PemMemory &pem) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!set_server_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else if (pem.client_ca_pem && pem.client_ca_pem_len > 0) { + if (!load_ca_pem(ctx_, pem.client_ca_pem, pem.client_ca_pem_len)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else { + set_verify_client(ctx_, true); + } + } + } +} + +SSLServer::SSLServer(const tls::ContextSetupCallback &setup_callback) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!setup_callback(ctx_)) { + free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSLServer::~SSLServer() { + if (ctx_) { tls::free_context(ctx_); } +} + +bool SSLServer::is_valid() const { return ctx_ != nullptr; } + +bool SSLServer::process_and_close_socket(socket_t sock) { + using namespace tls; + + // Create TLS session with mutex protection + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(static_cast(ctx_), sock); + } + + if (!session) { + last_ssl_error_ = static_cast(get_error()); + detail::shutdown_socket(sock); + detail::close_socket(sock); + return false; + } + + // Use scope_exit to ensure cleanup on all paths (including exceptions) + bool handshake_done = false; + bool ret = false; + auto cleanup = detail::scope_exit([&] { + // Shutdown gracefully if handshake succeeded and processing was successful + if (handshake_done) { shutdown(session, ret); } + free_session(session); + detail::shutdown_socket(sock); + detail::close_socket(sock); + }); + + // Perform TLS accept handshake with timeout + TlsError tls_err; + if (!accept_nonblocking(session, sock, read_timeout_sec_, read_timeout_usec_, + &tls_err)) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Map TlsError to legacy ssl_error for backward compatibility + if (tls_err.code == ErrorCode::WantRead) { + last_ssl_error_ = SSL_ERROR_WANT_READ; + } else if (tls_err.code == ErrorCode::WantWrite) { + last_ssl_error_ = SSL_ERROR_WANT_WRITE; + } else { + last_ssl_error_ = SSL_ERROR_SSL; + } +#else + last_ssl_error_ = static_cast(get_error()); +#endif + return false; + } + + handshake_done = true; + + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, session, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + [&](Request &req) { req.ssl = session; }); + }); + + return ret; +} + +bool SSLServer::update_certs_pem(const char *cert_pem, + const char *key_pem, + const char *client_ca_pem, + const char *password) { + if (!ctx_) { return false; } + std::lock_guard guard(ctx_mutex_); + if 
(!tls::update_server_cert(ctx_, cert_pem, key_pem, password)) { + return false; + } + if (client_ca_pem) { + return tls::update_server_client_ca(ctx_, client_ca_pem); + } + return true; +} + +// SSL HTTP client implementation +SSLClient::~SSLClient() { + if (ctx_) { tls::free_context(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +bool SSLClient::is_valid() const { return ctx_ != nullptr; } + +void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + tls::shutdown(socket.ssl, shutdown_gracefully); + { + std::lock_guard guard(ctx_mutex_); + tls::free_session(socket.ssl); + } + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +bool SSLClient::is_ssl() const { return true; } + +bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + if (!is_valid()) { + error = Error::SSLConnection; + return false; + } + return ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in +// flight +bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + // Close the current socket and create a new one for the authenticated + // request + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + + // Create a new socket for the authenticated CONNECT request + if (!ensure_socket_connection(socket, error)) { + success = false; + output_error_log(error, nullptr); + return false; + } + + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = 
"CONNECT"; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + output_error_log(error, nullptr); + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +// SSL HTTP client implementation +SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = tls::create_client_context(); + if (!ctx_) { return; } + + tls::set_min_version(ctx_, tls::Version::TLS1_2); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + const char *password = + private_key_password.empty() ? 
nullptr : private_key_password.c_str(); + if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(), + client_key_path.c_str(), password)) { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSLClient::SSLClient(const std::string &host, int port, + const PemMemory &pem) + : ClientImpl(host, port) { + ctx_ = tls::create_client_context(); + if (!ctx_) { return; } + + tls::set_min_version(ctx_, tls::Version::TLS1_2); + + if (pem.cert_pem && pem.key_pem) { + if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +void SSLClient::set_ca_cert_store(tls::ca_store_t ca_cert_store) { + if (ca_cert_store && ctx_) { + // set_ca_store takes ownership of ca_cert_store + tls::set_ca_store(ctx_, ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); + } +} + +void +SSLClient::set_server_certificate_verifier(tls::VerifyCallback verifier) { + if (!ctx_) { return; } + tls::set_verify_callback(ctx_, verifier); +} + +void SSLClient::set_session_verifier( + std::function verifier) { + session_verifier_ = std::move(verifier); +} + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +void SSLClient::enable_windows_certificate_verification(bool enabled) { + enable_windows_cert_verification_ = enabled; +} +#endif + +void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + if (ctx_ && ca_cert && size > 0) { + ca_cert_pem_.assign(ca_cert, size); // Store for redirect transfer + tls::load_ca_pem(ctx_, ca_cert, size); + } +} + +bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + + if (!ca_cert_file_path_.empty()) { + if (!tls::load_ca_file(ctx_, ca_cert_file_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!tls::load_ca_dir(ctx_, ca_cert_dir_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (ca_cert_pem_.empty()) { + if (!tls::load_system_certs(ctx_)) { + last_backend_error_ = tls::get_error(); + } + } + }); + + return ret; +} + +bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + using namespace tls; + + // Load CA certificates if server verification is enabled + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + output_error_log(error, nullptr); + return false; + } + } + + bool is_ip = detail::is_ip_address(host_); + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT + // MbedTLS needs explicit verification mode (OpenSSL uses SSL_VERIFY_NONE + // by default and performs all verification post-handshake). + // For IP addresses with verification enabled, use OPTIONAL mode since + // MbedTLS requires hostname for VERIFY_REQUIRED. 
+ if (is_ip && server_certificate_verification_) { + set_verify_client(ctx_, false); + } else { + set_verify_client(ctx_, server_certificate_verification_); + } +#endif + + // Create TLS session + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(ctx_, socket.sock); + } + + if (!session) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + + // Use scope_exit to ensure session is freed on error paths + bool success = false; + auto session_guard = detail::scope_exit([&] { + if (!success) { free_session(session); } + }); + + // Set SNI extension (skip for IP addresses per RFC 6066). + // On MbedTLS, set_sni also enables hostname verification internally. + // On OpenSSL, set_sni only sets SNI; verification is done post-handshake. + if (!is_ip) { + if (!set_sni(session, host_.c_str())) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + } + + // Perform non-blocking TLS handshake with timeout + TlsError tls_err; + if (!connect_nonblocking(session, socket.sock, connection_timeout_sec_, + connection_timeout_usec_, &tls_err)) { + last_ssl_error_ = static_cast(tls_err.code); + last_backend_error_ = tls_err.backend_code; + if (tls_err.code == ErrorCode::CertVerifyFailed) { + error = Error::SSLServerVerification; + } else if (tls_err.code == ErrorCode::HostnameMismatch) { + error = Error::SSLServerHostnameVerification; + } else { + error = Error::SSLConnection; + } + output_error_log(error, nullptr); + return false; + } + + // Post-handshake session verifier callback + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (session_verifier_) { verification_status = session_verifier_(session); } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + // Default server certificate verification + if (verification_status == SSLVerifierResponse::NoDecisionMade && + server_certificate_verification_) { + verify_result_ = tls::get_verify_result(session); + if (verify_result_ != 0) { + last_backend_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + auto server_cert = get_peer_cert(session); + if (!server_cert) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + auto cert_guard = detail::scope_exit([&] { free_cert(server_cert); }); + + // Hostname verification (post-handshake for all cases). + // On OpenSSL, verification is always post-handshake (SSL_VERIFY_NONE). + // On MbedTLS, set_sni already enabled hostname verification during + // handshake for non-IP hosts, but this check is still needed for IP + // addresses where SNI is not set. + if (server_hostname_verification_) { + if (!verify_hostname(server_cert, host_.c_str())) { + last_backend_error_ = hostname_mismatch_code(); + error = Error::SSLServerHostnameVerification; + output_error_log(error, nullptr); + return false; + } + } + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + // Additional Windows Schannel verification. + // This provides real-time certificate validation with Windows Update + // integration, working with both OpenSSL and MbedTLS backends. 
+ // Skip when a custom CA cert is specified, as the Windows certificate + // store would not know about user-provided CA certificates. + if (enable_windows_cert_verification_ && ca_cert_file_path_.empty() && + ca_cert_dir_path_.empty() && ca_cert_pem_.empty()) { + std::vector der; + if (get_cert_der(server_cert, der)) { + unsigned long wincrypt_error = 0; + if (!detail::verify_cert_with_windows_schannel( + der, host_, server_hostname_verification_, wincrypt_error)) { + last_backend_error_ = wincrypt_error; + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + } + } +#endif + } + + success = true; + socket.ssl = session; + return true; +} + +void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} + +void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} + +void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +void Client::enable_windows_certificate_verification(bool enabled) { + if (is_ssl_) { + static_cast(*cli_).enable_windows_certificate_verification( + enabled); + } +} +#endif + void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); } -void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { +void Client::set_ca_cert_store(tls::ca_store_t ca_cert_store) { if (is_ssl_) { static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } else { - cli_->set_ca_cert_store(ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); } } void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { - set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); + set_ca_cert_store(tls::create_ca_store(ca_cert, size)); } -long Client::get_openssl_verify_result() const { +void +Client::set_server_certificate_verifier(tls::VerifyCallback verifier) { if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); + static_cast(*cli_).set_server_certificate_verifier( + std::move(verifier)); } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? 
} +void Client::set_session_verifier( + std::function verifier) { + if (is_ssl_) { + static_cast(*cli_).set_session_verifier(std::move(verifier)); + } +} + +tls::ctx_t Client::tls_context() const { + if (is_ssl_) { return static_cast(*cli_).tls_context(); } + return nullptr; +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 7: TLS abstraction layer - Common API + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +namespace tls { + +// Helper for PeerCert construction +PeerCert get_peer_cert_from_session(const_session_t session) { + return PeerCert(get_peer_cert(session)); +} + +namespace impl { + +VerifyCallback &get_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +VerifyCallback &get_mbedtls_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +} // namespace impl + +bool set_client_ca_file(ctx_t ctx, const char *ca_file, + const char *ca_dir) { + if (!ctx) { return false; } + + bool success = true; + if (ca_file && *ca_file) { + if (!load_ca_file(ctx, ca_file)) { success = false; } + } + if (ca_dir && *ca_dir) { + if (!load_ca_dir(ctx, ca_dir)) { success = false; } + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Set CA list for client certificate request (CertificateRequest message) + if (ca_file && *ca_file) { + auto list = SSL_load_client_CA_file(ca_file); + if (list) { SSL_CTX_set_client_CA_list(static_cast(ctx), list); } + } +#endif + + return success; +} + +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + return set_client_cert_pem(ctx, cert, key, password); +} + +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + return set_client_cert_file(ctx, cert_path, key_path, password); +} + +// PeerCert implementation +PeerCert::PeerCert() = default; + +PeerCert::PeerCert(cert_t cert) : cert_(cert) {} + +PeerCert::PeerCert(PeerCert &&other) noexcept : cert_(other.cert_) { + other.cert_ = nullptr; +} + +PeerCert &PeerCert::operator=(PeerCert &&other) noexcept { + if (this != &other) { + if (cert_) { free_cert(cert_); } + cert_ = other.cert_; + other.cert_ = nullptr; + } + return *this; +} + +PeerCert::~PeerCert() { + if (cert_) { free_cert(cert_); } +} + +PeerCert::operator bool() const { return cert_ != nullptr; } + +std::string PeerCert::subject_cn() const { + return cert_ ? get_cert_subject_cn(cert_) : std::string(); +} + +std::string PeerCert::issuer_name() const { + return cert_ ? get_cert_issuer_name(cert_) : std::string(); +} + +bool PeerCert::check_hostname(const char *hostname) const { + return cert_ ? verify_hostname(cert_, hostname) : false; +} + +std::vector PeerCert::sans() const { + std::vector result; + if (cert_) { get_cert_sans(cert_, result); } + return result; +} + +bool PeerCert::validity(time_t ¬_before, time_t ¬_after) const { + return cert_ ? get_cert_validity(cert_, not_before, not_after) : false; +} + +std::string PeerCert::serial() const { + return cert_ ? get_cert_serial(cert_) : std::string(); +} + +// VerifyContext method implementations +std::string VerifyContext::subject_cn() const { + return cert ? get_cert_subject_cn(cert) : std::string(); +} + +std::string VerifyContext::issuer_name() const { + return cert ? get_cert_issuer_name(cert) : std::string(); +} + +bool VerifyContext::check_hostname(const char *hostname) const { + return cert ? 
verify_hostname(cert, hostname) : false; +} + +std::vector VerifyContext::sans() const { + std::vector result; + if (cert) { get_cert_sans(cert, result); } + return result; +} + +bool VerifyContext::validity(time_t ¬_before, + time_t ¬_after) const { + return cert ? get_cert_validity(cert, not_before, not_after) : false; +} + +std::string VerifyContext::serial() const { + return cert ? get_cert_serial(cert) : std::string(); +} + +// TlsError static method implementation +std::string TlsError::verify_error_to_string(long error_code) { + return verify_error_string(error_code); +} + +} // namespace tls + +// Request::peer_cert() implementation +tls::PeerCert Request::peer_cert() const { + return tls::get_peer_cert_from_session(ssl); +} + +// Request::sni() implementation +std::string Request::sni() const { + if (!ssl) { return std::string(); } + const char *s = tls::get_sni(ssl); + return s ? std::string(s) : std::string(); +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 8: TLS abstraction layer - OpenSSL backend + */ + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT SSL_CTX *Client::ssl_context() const { if (is_ssl_) { return static_cast(*cli_).ssl_context(); } return nullptr; } + +void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} + +long Client::get_verify_result() const { + if (is_ssl_) { return static_cast(*cli_).get_verify_result(); } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +/* + * OpenSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace tls { + +namespace impl { + +// OpenSSL-specific helpers for converting native types to PEM +std::string x509_to_pem(X509 *cert) { + if (!cert) return {}; + BIO *bio = BIO_new(BIO_s_mem()); + if (!bio) return {}; + if (PEM_write_bio_X509(bio, cert) != 1) { + BIO_free(bio); + return {}; + } + char *data = nullptr; + long len = BIO_get_mem_data(bio, &data); + std::string pem(data, static_cast(len)); + BIO_free(bio); + return pem; +} + +std::string evp_pkey_to_pem(EVP_PKEY *key) { + if (!key) return {}; + BIO *bio = BIO_new(BIO_s_mem()); + if (!bio) return {}; + if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr, + nullptr) != 1) { + BIO_free(bio); + return {}; + } + char *data = nullptr; + long len = BIO_get_mem_data(bio, &data); + std::string pem(data, static_cast(len)); + BIO_free(bio); + return pem; +} + +std::string x509_store_to_pem(X509_STORE *store) { + if (!store) return {}; + std::string pem; + auto objs = X509_STORE_get0_objects(store); + if (!objs) return {}; + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto cert = X509_OBJECT_get0_X509(obj); + if (cert) { pem += x509_to_pem(cert); } + } + } + return pem; +} + +// Helper to map OpenSSL SSL_get_error to ErrorCode +ErrorCode map_ssl_error(int ssl_error, int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + case SSL_ERROR_SSL: + default: return ErrorCode::Fatal; + } +} + +// Helper: Create client CA list from PEM string +// Returns a new STACK_OF(X509_NAME)* or nullptr on failure +// Caller takes 
ownership of returned list +STACK_OF(X509_NAME) * + create_client_ca_list_from_pem(const char *ca_pem) { + if (!ca_pem) { return nullptr; } + + auto ca_list = sk_X509_NAME_new_null(); + if (!ca_list) { return nullptr; } + + BIO *bio = BIO_new_mem_buf(ca_pem, -1); + if (!bio) { + sk_X509_NAME_pop_free(ca_list, X509_NAME_free); + return nullptr; + } + + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + X509_NAME *name = X509_get_subject_name(cert); + if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); } + X509_free(cert); + } + BIO_free(bio); + + return ca_list; +} + +// Helper: Extract CA names from X509_STORE +// Returns a new STACK_OF(X509_NAME)* or nullptr on failure +// Caller takes ownership of returned list +STACK_OF(X509_NAME) * + extract_client_ca_list_from_store(X509_STORE *store) { + if (!store) { return nullptr; } + + auto ca_list = sk_X509_NAME_new_null(); + if (!ca_list) { return nullptr; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { + sk_X509_NAME_free(ca_list); + return nullptr; + } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto cert = X509_OBJECT_get0_X509(obj); + if (cert) { + auto subject = X509_get_subject_name(cert); + if (subject) { + auto name_dup = X509_NAME_dup(subject); + if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } + } + } + } + } + + if (sk_X509_NAME_num(ca_list) == 0) { + sk_X509_NAME_free(ca_list); + return nullptr; + } + + return ca_list; +} + +// OpenSSL verify callback wrapper +int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + // Get SSL object from X509_STORE_CTX + auto ssl = static_cast( + X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + if (!ssl) { return preverify_ok; } + + // Get current certificate and depth + auto cert = X509_STORE_CTX_get_current_cert(ctx); + int depth = X509_STORE_CTX_get_error_depth(ctx); + int error = X509_STORE_CTX_get_error(ctx); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = error; + verify_ctx.error_string = + (error != X509_V_OK) ? X509_verify_cert_error_string(error) : nullptr; + + return callback(verify_ctx) ? 
1 : 0; +} + +} // namespace impl + +ctx_t create_client_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_client_method()); + if (ctx) { + // Disable auto-retry to properly handle non-blocking I/O + SSL_CTX_clear_mode(ctx, SSL_MODE_AUTO_RETRY); + // Set minimum TLS version + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { SSL_CTX_free(static_cast(ctx)); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) return false; + return SSL_CTX_set_min_proto_version(static_cast(ctx), + static_cast(version)) == 1; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem || len == 0) return false; + + auto ssl_ctx = static_cast(ctx); + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + auto bio = BIO_new_mem_buf(pem, static_cast(len)); + if (!bio) return false; + + bool ok = true; + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + if (X509_STORE_add_cert(store, cert) != 1) { + // Ignore duplicate errors + auto err = ERR_peek_last_error(); + if (ERR_GET_REASON(err) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + ok = false; + } + } + X509_free(cert); + if (!ok) break; + } + BIO_free(bio); + + // Clear any "no more certificates" errors + ERR_clear_error(); + return ok; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), file_path, + nullptr) == 1; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), nullptr, + dir_path) == 1; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) return false; + auto ssl_ctx = static_cast(ctx); + +#ifdef _WIN32 + // Windows: Load from system certificate store (ROOT and CA) + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + bool loaded_any = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + auto hStore = CertOpenSystemStoreW(NULL, store_name); + if (!hStore) continue; + + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + const unsigned char *data = pContext->pbCertEncoded; + auto x509 = d2i_X509(nullptr, &data, pContext->cbCertEncoded); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + } + CertCloseStore(hStore, 0); + } + return loaded_any; + +#elif defined(__APPLE__) +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + // macOS: Load from Keychain + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + CFArrayRef certs = nullptr; + if (SecTrustCopyAnchorCertificates(&certs) != errSecSuccess || !certs) { + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; + } + + bool loaded_any = false; + auto count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + auto cert = reinterpret_cast( + const_cast(CFArrayGetValueAtIndex(certs, i))); + CFDataRef der = SecCertificateCopyData(cert); + if (der) { + const unsigned char *data = CFDataGetBytePtr(der); + auto x509 = d2i_X509(nullptr, &data, CFDataGetLength(der)); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + CFRelease(der); + } + } + CFRelease(certs); + return loaded_any || SSL_CTX_set_default_verify_paths(ssl_ctx) 
== 1; +#else + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; #endif +#else + // Other Unix: use default verify paths + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#endif +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) return false; + + auto ssl_ctx = static_cast(ctx); + + // Load certificate + auto cert_bio = BIO_new_mem_buf(cert, -1); + if (!cert_bio) return false; + + auto x509 = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!x509) return false; + + auto cert_ok = SSL_CTX_use_certificate(ssl_ctx, x509) == 1; + X509_free(x509); + if (!cert_ok) return false; + + // Load private key + auto key_bio = BIO_new_mem_buf(key, -1); + if (!key_bio) return false; + + auto pkey = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!pkey) return false; + + auto key_ok = SSL_CTX_use_PrivateKey(ssl_ctx, pkey) == 1; + EVP_PKEY_free(pkey); + + return key_ok && SSL_CTX_check_private_key(ssl_ctx) == 1; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) return false; + + auto ssl_ctx = static_cast(ctx); + + if (password && password[0] != '\0') { + SSL_CTX_set_default_passwd_cb_userdata( + ssl_ctx, reinterpret_cast(const_cast(password))); + } + + return SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_path) == 1 && + SSL_CTX_use_PrivateKey_file(ssl_ctx, key_path, SSL_FILETYPE_PEM) == 1; +} + +ctx_t create_server_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_server_method()); + if (ctx) { + SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) return; + SSL_CTX_set_verify(static_cast(ctx), + require + ? 
(SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) return nullptr; + + auto ssl_ctx = static_cast(ctx); + SSL *ssl = SSL_new(ssl_ctx); + if (!ssl) return nullptr; + + // Disable auto-retry for proper non-blocking I/O handling + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + if (!bio) { + SSL_free(ssl); + return nullptr; + } + + SSL_set_bio(ssl, bio, bio); + return static_cast(ssl); +} + +void free_session(session_t session) { + if (session) { SSL_free(static_cast(session)); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) only - does not enable verification +#if defined(OPENSSL_IS_BORINGSSL) + return SSL_set_tlsext_host_name(ssl, hostname) == 1; +#else + // Direct call instead of macro to suppress -Wold-style-cast warning + return SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(hostname))) == 1; +#endif +} + +bool set_hostname(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) + if (!set_sni(session, hostname)) { return false; } + + // Enable hostname verification + auto param = SSL_get0_param(ssl); + if (!param) return false; + + X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); + if (X509_VERIFY_PARAM_set1_host(param, hostname, 0) != 1) { return false; } + + SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr); + return true; +} + +TlsError connect(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_connect(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +TlsError accept(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_accept(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_connect(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); 
+ err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_accept(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); + err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + constexpr auto max_len = + static_cast((std::numeric_limits::max)()); + if (len > max_len) { len = max_len; } + auto ret = SSL_read(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + auto ret = SSL_write(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +int pending(const_session_t session) { + if (!session) return 0; + return SSL_pending(static_cast(const_cast(session))); +} + +void shutdown(session_t session, bool graceful) { + if (!session) return; + + auto ssl = static_cast(session); + if (graceful) { + // First call sends close_notify + if (SSL_shutdown(ssl) == 0) { + // Second call waits for peer's close_notify + SSL_shutdown(ssl); + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session) return true; + + // Temporarily set socket to non-blocking to avoid blocking on SSL_peek + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + auto ssl = static_cast(session); + char buf; + auto ret = SSL_peek(ssl, &buf, 1); + if (ret > 0) return false; + + auto err = SSL_get_error(ssl, ret); + return err == SSL_ERROR_ZERO_RETURN; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) return nullptr; + return static_cast(SSL_get1_peer_certificate( + static_cast(const_cast(session)))); +} + +void free_cert(cert_t cert) { + if (cert) { X509_free(static_cast(cert)); } +} + +bool 
verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) return false; + + auto x509 = static_cast(cert); + + // Use X509_check_ip_asc for IP addresses, X509_check_host for DNS names + if (detail::is_ip_address(hostname)) { + return X509_check_ip_asc(x509, hostname, 0) == 1; + } + return X509_check_host(x509, hostname, strlen(hostname), 0, nullptr) == 1; +} + +uint64_t hostname_mismatch_code() { + return static_cast(X509_V_ERR_HOSTNAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) return X509_V_ERR_UNSPECIFIED; + return SSL_get_verify_result(static_cast(const_cast(session))); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto subject_name = X509_get_subject_name(x509); + if (!subject_name) return ""; + + char buf[256]; + auto len = + X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf, sizeof(buf)); + if (len < 0) return ""; + return std::string(buf, static_cast(len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto issuer_name = X509_get_issuer_name(x509); + if (!issuer_name) return ""; + + char buf[256]; + X509_NAME_oneline(issuer_name, buf, sizeof(buf)); + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto names = static_cast( + X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!names) return true; // No SANs is valid + + auto count = sk_GENERAL_NAME_num(names); + for (int i = 0; i < count; i++) { + auto gen = sk_GENERAL_NAME_value(names, i); + if (!gen) continue; + + SanEntry entry; + switch (gen->type) { + case GEN_DNS: + entry.type = SanType::DNS; + if (gen->d.dNSName) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.dNSName)), + static_cast(ASN1_STRING_length(gen->d.dNSName))); + } + break; + case GEN_IPADD: + entry.type = SanType::IP; + if (gen->d.iPAddress) { + auto data = ASN1_STRING_get0_data(gen->d.iPAddress); + auto len = ASN1_STRING_length(gen->d.iPAddress); + if (len == 4) { + // IPv4 + char buf[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, data, buf, sizeof(buf)); + entry.value = buf; + } else if (len == 16) { + // IPv6 + char buf[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, data, buf, sizeof(buf)); + entry.value = buf; + } + } + break; + case GEN_EMAIL: + entry.type = SanType::EMAIL; + if (gen->d.rfc822Name) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.rfc822Name)), + static_cast(ASN1_STRING_length(gen->d.rfc822Name))); + } + break; + case GEN_URI: + entry.type = SanType::URI; + if (gen->d.uniformResourceIdentifier) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.uniformResourceIdentifier)), + static_cast( + ASN1_STRING_length(gen->d.uniformResourceIdentifier))); + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + + GENERAL_NAMES_free(names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + auto nb = X509_get0_notBefore(x509); + auto na = X509_get0_notAfter(x509); + if (!nb || !na) return false; + + ASN1_TIME *epoch = ASN1_TIME_new(); + if (!epoch) return false; + auto se = detail::scope_exit([&] { ASN1_TIME_free(epoch); }); + + if (!ASN1_TIME_set(epoch, 
0)) return false; + + int pday, psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, nb)) return false; + not_before = 86400 * (time_t)pday + psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, na)) return false; + not_after = 86400 * (time_t)pday + psec; + + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + auto serial = X509_get_serialNumber(x509); + if (!serial) return ""; + + auto bn = ASN1_INTEGER_to_BN(serial, nullptr); + if (!bn) return ""; + + auto hex = BN_bn2hex(bn); + BN_free(bn); + if (!hex) return ""; + + std::string result(hex); + OPENSSL_free(hex); + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + auto len = i2d_X509(x509, nullptr); + if (len < 0) return false; + der.resize(static_cast(len)); + auto p = der.data(); + i2d_X509(x509, &p); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto ssl = static_cast(const_cast(session)); + return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); +} + +uint64_t peek_error() { return ERR_peek_last_error(); } + +uint64_t get_error() { return ERR_get_error(); } + +std::string error_string(uint64_t code) { + char buf[256]; + ERR_error_string_n(static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto mem = BIO_new_mem_buf(pem, static_cast(len)); + if (!mem) { return nullptr; } + auto mem_guard = detail::scope_exit([&] { BIO_free_all(mem); }); + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto store = X509_STORE_new(); + if (store) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + if (itmp->x509) { X509_STORE_add_cert(store, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(store, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return static_cast(store); +} + +void free_ca_store(ca_store_t store) { + if (store) { X509_STORE_free(static_cast(store)); } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto ssl_ctx = static_cast(ctx); + auto x509_store = static_cast(store); + + // Check if same store is already set + if (SSL_CTX_get_cert_store(ssl_ctx) == x509_store) { return true; } + + // SSL_CTX_set_cert_store takes ownership and frees the old store + SSL_CTX_set_cert_store(ssl_ctx, x509_store); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return 0; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { return 0; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + // Increment reference count so caller can free it + X509_up_ref(x509); + certs.push_back(static_cast(x509)); + } + } + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return names; } + + auto objs = 
X509_STORE_get0_objects(store); + if (!objs) { return names; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + auto subject = X509_get_subject_name(x509); + if (subject) { + char buf[512]; + X509_NAME_oneline(subject, buf, sizeof(buf)); + names.push_back(buf); + } + } + } + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Load certificate from PEM + auto cert_bio = BIO_new_mem_buf(cert_pem, -1); + if (!cert_bio) { return false; } + auto cert = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!cert) { return false; } + + // Load private key from PEM + auto key_bio = BIO_new_mem_buf(key_pem, -1); + if (!key_bio) { + X509_free(cert); + return false; + } + auto key = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!key) { + X509_free(cert); + return false; + } + + // Update certificate and key + auto ret = SSL_CTX_use_certificate(ssl_ctx, cert) == 1 && + SSL_CTX_use_PrivateKey(ssl_ctx, key) == 1; + + X509_free(cert); + EVP_PKEY_free(key); + return ret; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Create new X509_STORE from PEM + auto store = create_ca_store(ca_pem, strlen(ca_pem)); + if (!store) { return false; } + + // SSL_CTX_set_cert_store takes ownership + SSL_CTX_set_cert_store(ssl_ctx, static_cast(store)); + + // Set client CA list for client certificate request + auto ca_list = impl::create_client_ca_list_from_pem(ca_pem); + if (ca_list) { + // SSL_CTX_set_client_CA_list takes ownership of ca_list + SSL_CTX_set_client_CA_list(ssl_ctx, ca_list); + } + + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto ssl_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + + if (impl::get_verify_callback()) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, impl::openssl_verify_callback); + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto ssl = static_cast(const_cast(session)); + return SSL_get_verify_result(ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == X509_V_OK) { return ""; } + const char *str = X509_verify_cert_error_string(static_cast(error_code)); + return str ? 
str : "unknown error"; +} + +namespace impl { + +// OpenSSL-specific helpers for public API wrappers +ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key, + X509_STORE *client_ca_store, + int &out_error) { + out_error = 0; + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + if (cert_pem.empty() || key_pem.empty()) { + out_error = static_cast(ERR_get_error()); + return nullptr; + } + + auto ctx = create_server_context(); + if (!ctx) { + out_error = static_cast(get_error()); + return nullptr; + } + + if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) { + out_error = static_cast(get_error()); + free_context(ctx); + return nullptr; + } + + if (client_ca_store) { + // Set cert store for verification (SSL_CTX_set_cert_store takes ownership) + SSL_CTX_set_cert_store(static_cast(ctx), client_ca_store); + + // Extract and set client CA list directly from store (more efficient than + // PEM conversion) + auto ca_list = extract_client_ca_list_from_store(client_ca_store); + if (ca_list) { + SSL_CTX_set_client_CA_list(static_cast(ctx), ca_list); + } + + set_verify_client(ctx, true); + } + + return ctx; +} + +void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key, + X509_STORE *client_ca_store) { + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + + if (!cert_pem.empty() && !key_pem.empty()) { + update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr); + } + + if (client_ca_store) { + auto ca_pem = x509_store_to_pem(client_ca_store); + if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); } + X509_STORE_free(client_ca_store); + } +} + +ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, + const char *password, + unsigned long &out_error) { + out_error = 0; + auto ctx = create_client_context(); + if (!ctx) { + out_error = static_cast(get_error()); + return nullptr; + } + + if (cert && key) { + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + if (cert_pem.empty() || key_pem.empty()) { + out_error = ERR_get_error(); + free_context(ctx); + return nullptr; + } + if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), + password)) { + out_error = static_cast(get_error()); + free_context(ctx); + return nullptr; + } + } + + return ctx; +} + +} // namespace impl + +} // namespace tls + +// ClientImpl::set_ca_cert_store - defined here to use +// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and +// stores for redirect transfer +void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store); + } +} + +SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = tls::impl::create_server_context_from_x509( + cert, private_key, client_ca_cert_store, last_ssl_error_); +} + +SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + // Use abstract API to create context + ctx_ = tls::create_server_context(); + if (ctx_) { + // Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally) + auto ssl_ctx = static_cast(ctx_); + if (!setup_ssl_ctx_callback(*ssl_ctx)) { + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSL_CTX *SSLServer::ssl_context() const { + return static_cast(ctx_); +} + +void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + std::lock_guard guard(ctx_mutex_); + 
tls::impl::update_server_certs_from_x509(ctx_, cert, private_key, + client_ca_cert_store); +} + +SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + const char *password = + private_key_password.empty() ? nullptr : private_key_password.c_str(); + ctx_ = tls::impl::create_client_context_from_x509( + client_cert, client_key, password, last_backend_error_); +} + +long SSLClient::get_verify_result() const { return verify_result_; } + +void SSLClient::set_server_certificate_verifier( + std::function verifier) { + // Wrap SSL* callback into backend-independent session_verifier_ + auto v = std::make_shared>( + std::move(verifier)); + session_verifier_ = [v](tls::session_t session) { + return (*v)(static_cast(session)); + }; +} + +SSL_CTX *SSLClient::ssl_context() const { + return static_cast(ctx_); +} + +bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. 
+ + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (!val || val->type != type) { continue; } + + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + if (name == nullptr) { continue; } + + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = + detail::match_hostname(std::string(name, name_len), host_); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return detail::match_hostname( + std::string(name, static_cast(name_len)), host_); + } + } + + return false; +} + +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +/* + * Group 9: TLS abstraction layer - Mbed TLS backend + */ + +/* + * Mbed TLS Backend Implementation + */ + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { + +namespace impl { + +// Mbed TLS session wrapper +struct MbedTlsSession { + mbedtls_ssl_context ssl; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + MbedTlsSession() { mbedtls_ssl_init(&ssl); } + + ~MbedTlsSession() { mbedtls_ssl_free(&ssl); } + + MbedTlsSession(const MbedTlsSession &) = delete; + MbedTlsSession &operator=(const MbedTlsSession &) = delete; +}; + +// Thread-local error code accessor for Mbed TLS (since it doesn't have an error +// queue) +int &mbedtls_last_error() { + static thread_local int err = 0; + return err; +} + +// Helper to map Mbed TLS error to ErrorCode +ErrorCode map_mbedtls_error(int ret, int &out_errno) { + if (ret == 0) { return ErrorCode::Success; } + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { return ErrorCode::WantRead; } + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { return ErrorCode::WantWrite; } + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return ErrorCode::PeerClosed; + } + if (ret == MBEDTLS_ERR_NET_CONN_RESET || ret == MBEDTLS_ERR_NET_SEND_FAILED || + ret == MBEDTLS_ERR_NET_RECV_FAILED) { + out_errno = errno; + return ErrorCode::SyscallError; + } + if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) { + return ErrorCode::CertVerifyFailed; + } + return ErrorCode::Fatal; +} + +// 
BIO-like send callback for Mbed TLS +int mbedtls_net_send_cb(void *ctx, const unsigned char *buf, + size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + send(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_WRITE; } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#else + auto ret = send(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#endif + return static_cast(ret); +} + +// BIO-like recv callback for Mbed TLS +int mbedtls_net_recv_cb(void *ctx, unsigned char *buf, size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + recv(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_READ; } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#else + auto ret = recv(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#endif + if (ret == 0) { return MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; } + return static_cast(ret); +} + +// MbedTlsContext constructor/destructor implementations +MbedTlsContext::MbedTlsContext() { + mbedtls_ssl_config_init(&conf); + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_x509_crt_init(&ca_chain); + mbedtls_x509_crt_init(&own_cert); + mbedtls_pk_init(&own_key); +} + +MbedTlsContext::~MbedTlsContext() { + mbedtls_pk_free(&own_key); + mbedtls_x509_crt_free(&own_cert); + mbedtls_x509_crt_free(&ca_chain); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + mbedtls_ssl_config_free(&conf); +} + +// Thread-local storage for SNI captured during handshake +// This is needed because the SNI callback doesn't have a way to pass +// session-specific data before the session is fully set up +std::string &mbedpending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for Mbed TLS server to capture client's SNI hostname +int mbedtls_sni_callback(void *p_ctx, mbedtls_ssl_context *ssl, + const unsigned char *name, size_t name_len) { + (void)p_ctx; + (void)ssl; + + // Store SNI name in thread-local storage + // It will be retrieved and stored in the session after handshake + if (name && name_len > 0) { + mbedpending_sni().assign(reinterpret_cast(name), name_len); + } else { + mbedpending_sni().clear(); + } + return 0; // Accept any SNI +} + +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags); + +// Check if a string is an IPv4 address +bool is_ipv4_address(const std::string &str) { + int dots = 0; + for (char c : str) { + if (c == '.') { + dots++; + } else if (!isdigit(static_cast(c))) { + return false; + } + } + return dots == 3; +} + +// Parse IPv4 address string to bytes +bool parse_ipv4(const std::string &str, unsigned char *out) { + int parts[4]; + if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], + &parts[3]) != 4) { + return false; + } + for (int i = 0; i < 4; i++) { + if (parts[i] < 0 || parts[i] > 255) return false; + out[i] = static_cast(parts[i]); + } + return true; +} + +// MbedTLS verify callback wrapper +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags) { + auto &callback = 
get_verify_callback(); + if (!callback) { return 0; } // Continue with default verification + + // data points to the MbedTlsSession + auto *session = static_cast(data); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(session); + verify_ctx.cert = static_cast(crt); + verify_ctx.depth = cert_depth; + verify_ctx.preverify_ok = (*flags == 0); + verify_ctx.error_code = static_cast(*flags); + + // Convert Mbed TLS flags to error string + static thread_local char error_buf[256]; + if (*flags != 0) { + mbedtls_x509_crt_verify_info(error_buf, sizeof(error_buf), "", *flags); + verify_ctx.error_string = error_buf; + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + + if (accepted) { + *flags = 0; // Clear all error flags + return 0; + } + return MBEDTLS_ERR_X509_CERT_VERIFY_FAILED; +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + // Seed the random number generator + const char *pers = "httplib_client"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for client + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: verify peer certificate + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + // Seed the random number generator + const char *pers = "httplib_server"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for server + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: don't verify client + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_NONE); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + // Set SNI callback to capture client's SNI hostname + mbedtls_ssl_conf_sni(&ctx->conf, impl::mbedtls_sni_callback, nullptr); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool 
set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + // Mbed TLS 3.x uses mbedtls_ssl_protocol_version enum + mbedtls_ssl_protocol_version min_ver = MBEDTLS_SSL_VERSION_TLS1_2; + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + min_ver = MBEDTLS_SSL_VERSION_TLS1_3; +#endif + } + mbedtls_ssl_conf_min_tls_version(&mctx->conf, min_ver); +#else + // Mbed TLS 2.x uses major/minor version numbers + int major = MBEDTLS_SSL_MAJOR_VERSION_3; + int minor = MBEDTLS_SSL_MINOR_VERSION_3; // TLS 1.2 + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + minor = MBEDTLS_SSL_MINOR_VERSION_4; // TLS 1.3 +#else + minor = MBEDTLS_SSL_MINOR_VERSION_3; // Fall back to TLS 1.2 +#endif + } + mbedtls_ssl_conf_min_version(&mctx->conf, major, minor); +#endif + return true; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto mctx = static_cast(ctx); + + // mbedtls_x509_crt_parse expects null-terminated string for PEM + // Add null terminator if not present + std::string pem_str(pem, len); + int ret = mbedtls_x509_crt_parse( + &mctx->ca_chain, reinterpret_cast(pem_str.c_str()), + pem_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, file_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, dir_path); + if (ret < 0) { // Returns number of certs on success, negative on error + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + // Load from Windows certificate store (ROOT and CA) + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); + if (hStore) { + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + int ret = mbedtls_x509_crt_parse_der( + &mctx->ca_chain, pContext->pbCertEncoded, pContext->cbCertEncoded); + if (ret == 0) { loaded = true; } + } + CertCloseStore(hStore, 0); + } + } +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + // Load from macOS Keychain + CFArrayRef certs = nullptr; + OSStatus status = SecTrustCopyAnchorCertificates(&certs); + if (status == errSecSuccess && certs) { + CFIndex count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + SecCertificateRef cert = + (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + CFDataRef data = SecCertificateCopyData(cert); + if (data) { + int ret = mbedtls_x509_crt_parse_der( + &mctx->ca_chain, CFDataGetBytePtr(data), + static_cast(CFDataGetLength(data))); + if (ret == 0) { loaded = true; } + CFRelease(data); + 
} + } + CFRelease(certs); + } +#else + // Try common CA certificate locations on Linux/Unix + static const char *ca_paths[] = { + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu + "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/ssl/cert.pem", // Alpine, FreeBSD + nullptr}; + + for (const char **path = ca_paths; *path; ++path) { + int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path); + if (ret >= 0) { + loaded = true; + break; + } + } + + // Also try the CA directory + if (!loaded) { + static const char *ca_dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu + "/etc/pki/tls/certs", // RHEL/CentOS + "/usr/share/ca-certificates", nullptr}; + + for (const char **dir = ca_dirs; *dir; ++dir) { + int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir); + if (ret >= 0) { + loaded = true; + break; + } + } + } +#endif + + if (loaded) { + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + } + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate + std::string cert_str(cert); + int ret = mbedtls_x509_crt_parse( + &mctx->own_cert, + reinterpret_cast(cert_str.c_str()), + cert_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key + std::string key_str(key); + const unsigned char *pwd = + password ? reinterpret_cast(password) : nullptr; + size_t pwd_len = password ? strlen(password) : 0; + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len, mbedtls_ctr_drbg_random, + &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate file + int ret = mbedtls_x509_crt_parse_file(&mctx->own_cert, cert_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key file +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto mctx = static_cast(ctx); + mctx->verify_client = require; + if (require) { + mbedtls_ssl_conf_authmode(&mctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + } else { + // If a verify callback is set, use OPTIONAL mode to ensure the callback + // is called (matching OpenSSL behavior). Otherwise use NONE. 
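// (annotation, not part of the patch) MBEDTLS_SSL_VERIFY_OPTIONAL still runs the
// chain check and records the result for mbedtls_ssl_get_verify_result(), but it
// lets the handshake proceed so a registered callback can accept or reject;
// MBEDTLS_SSL_VERIFY_NONE skips peer verification entirely.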
+ mbedtls_ssl_conf_authmode(&mctx->conf, mctx->has_verify_callback + ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_NONE); + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto mctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::MbedTlsSession(); + if (!session) { return nullptr; } + + session->sock = sock; + + int ret = mbedtls_ssl_setup(&session->ssl, &mctx->conf); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete session; + return nullptr; + } + + // Set BIO callbacks + mbedtls_ssl_set_bio(&session->ssl, &session->sock, impl::mbedtls_net_send_cb, + impl::mbedtls_net_recv_cb, nullptr); + + // Set per-session verify callback with session pointer if callback is + // registered + if (mctx->has_verify_callback) { + mbedtls_ssl_set_verify(&session->ssl, impl::mbedtls_verify_callback, + session); + } + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto msession = static_cast(session); + + int ret = mbedtls_ssl_set_hostname(&msession->ssl, hostname); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + msession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In Mbed TLS, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_handshake(&msession->ssl); + + if (ret == 0) { + err.code = ErrorCode::Success; + } else { + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + } + + return err; +} + +TlsError accept(session_t session) { + // Same as connect for Mbed TLS - handshake works for both client and server + auto result = connect(session); + + // After successful handshake, capture SNI from thread-local storage + if (result.code == ErrorCode::Success && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto msession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = mbedtls_ssl_handshake(&msession->ssl)) != 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // TlsError or timeout + if (err) { + err->code = impl::map_mbedtls_error(ret, err->sys_errno); + err->backend_code = static_cast(-ret); + } + impl::mbedtls_last_error() = ret; + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t 
timeout_usec, + TlsError *err) { + // Same implementation as connect for Mbed TLS + bool result = + connect_nonblocking(session, sock, timeout_sec, timeout_usec, err); + + // After successful handshake, capture SNI from thread-local storage + if (result && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = + mbedtls_ssl_read(&msession->ssl, static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_write(&msession->ssl, + static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_bytes_avail(&msession->ssl)); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto msession = static_cast(session); + + if (graceful) { + // Try to send close_notify, but don't block forever + int ret; + int attempts = 0; + while ((ret = mbedtls_ssl_close_notify(&msession->ssl)) != 0 && + attempts < 3) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + break; + } + attempts++; + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto msession = static_cast(session); + + // Check if there's already decrypted data available in the TLS buffer + // If so, the connection is definitely alive + if (mbedtls_ssl_get_bytes_avail(&msession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Try a 1-byte read to check connection status + // Note: This will consume the byte if data is available, but for the + // purpose of checking if peer is closed, this should be acceptable + // since we're only called when we expect the connection might be closing + unsigned char buf; + int ret = mbedtls_ssl_read(&msession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0 || ret == MBEDTLS_ERR_SSL_WANT_READ) { return false; } + + // If we get a peer close notify or a connection reset, the peer is closed + return ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || + ret == MBEDTLS_ERR_NET_CONN_RESET || ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto msession = + 
static_cast(const_cast(session)); + + // Mbed TLS returns a pointer to the internal peer cert chain. + // WARNING: This pointer is only valid while the session is active. + // Do not use the certificate after calling free_session(). + const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&msession->ssl); + return const_cast(cert); +} + +void free_cert(cert_t cert) { + // Mbed TLS: peer certificate is owned by the SSL context. + // No-op here, but callers should still call this for cross-backend + // portability. + (void)cert; +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto mcert = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names (SAN) + // In Mbed TLS 3.x, subject_alt_names contains raw values without ASN.1 tags + // - DNS names: raw string bytes + // - IP addresses: raw IP bytes (4 for IPv4, 16 for IPv6) + const mbedtls_x509_sequence *san = &mcert->subject_alt_names; + while (san != nullptr && san->buf.p != nullptr && san->buf.len > 0) { + const unsigned char *p = san->buf.p; + size_t len = san->buf.len; + + if (is_ip) { + // Check if this SAN is an IPv4 address (4 bytes) + if (len == 4 && memcmp(p, ip_bytes, 4) == 0) { return true; } + // Check if this SAN is an IPv6 address (16 bytes) - skip for now + } else { + // Check if this SAN is a DNS name (printable ASCII string) + bool is_dns = len > 0; + for (size_t i = 0; i < len && is_dns; i++) { + if (p[i] < 32 || p[i] > 126) { is_dns = false; } + } + if (is_dns) { + std::string san_name(reinterpret_cast(p), len); + if (detail::match_hostname(san_name, host_str)) { return true; } + } + } + san = san->next; + } + + // Fallback: Check Common Name (CN) in subject + char cn[256]; + int ret = mbedtls_x509_dn_gets(cn, sizeof(cn), &mcert->subject); + if (ret > 0) { + std::string cn_str(cn); + + // Look for "CN=" in the DN string + size_t cn_pos = cn_str.find("CN="); + if (cn_pos != std::string::npos) { + size_t start = cn_pos + 3; + size_t end = cn_str.find(',', start); + std::string cn_value = + cn_str.substr(start, end == std::string::npos ? end : end - start); + + if (detail::match_hostname(cn_value, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(MBEDTLS_X509_BADCERT_CN_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto msession = + static_cast(const_cast(session)); + uint32_t flags = mbedtls_ssl_get_verify_result(&msession->ssl); + // Return 0 (X509_V_OK equivalent) if verification passed + return flags == 0 ? 
0 : static_cast(flags); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Find the CN in the subject + const mbedtls_x509_name *name = &x509->subject; + while (name != nullptr) { + if (MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &name->oid) == 0) { + return std::string(reinterpret_cast(name->val.p), + name->val.len); + } + name = name->next; + } + return ""; +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Build a human-readable issuer name string + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &x509->issuer); + if (ret < 0) return ""; + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + // Parse the Subject Alternative Name extension + const mbedtls_x509_sequence *cur = &x509->subject_alt_names; + while (cur != nullptr) { + if (cur->buf.len > 0) { + // Mbed TLS stores SAN as ASN.1 sequences + // The tag byte indicates the type + const unsigned char *p = cur->buf.p; + size_t len = cur->buf.len; + + // First byte is the tag + unsigned char tag = *p; + p++; + len--; + + // Parse length (simple single-byte length assumed) + if (len > 0 && *p < 0x80) { + size_t value_len = *p; + p++; + len--; + + if (value_len <= len) { + SanEntry entry; + // ASN.1 context tags for GeneralName + switch (tag & 0x1F) { + case 2: // dNSName + entry.type = SanType::DNS; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 7: // iPAddress + entry.type = SanType::IP; + if (value_len == 4) { + // IPv4 + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", p[0], p[1], p[2], p[3]); + entry.value = buf; + } else if (value_len == 16) { + // IPv6 + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], + p[9], p[10], p[11], p[12], p[13], p[14], p[15]); + entry.value = buf; + } + break; + case 1: // rfc822Name (email) + entry.type = SanType::EMAIL; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 6: // uniformResourceIdentifier + entry.type = SanType::URI; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + } + } + cur = cur->next; + } + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + // Convert mbedtls_x509_time to time_t + auto to_time_t = [](const mbedtls_x509_time &t) -> time_t { + struct tm tm_time = {}; + tm_time.tm_year = t.year - 1900; + tm_time.tm_mon = t.mon - 1; + tm_time.tm_mday = t.day; + tm_time.tm_hour = t.hour; + tm_time.tm_min = t.min; + tm_time.tm_sec = t.sec; +#ifdef _WIN32 + return _mkgmtime(&tm_time); +#else + return timegm(&tm_time); +#endif + }; + + not_before = to_time_t(x509->valid_from); + not_after = to_time_t(x509->valid_to); + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Convert serial number to hex string + std::string result; + result.reserve(x509->serial.len * 2); + for (size_t i = 0; i < x509->serial.len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", x509->serial.p[i]); + result += hex; + } + return result; +} + 
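For orientation, a minimal usage sketch of the backend-neutral certificate accessors defined above (not part of the patch; `session` and the hostname literal are illustrative placeholders supplied by the caller, and error handling is omitted):

// Sketch only: inspect the peer certificate through the tls:: abstraction.
if (auto cert = tls::get_peer_cert(session)) {
  auto cn = tls::get_cert_subject_cn(cert);
  std::vector<tls::SanEntry> sans;
  tls::get_cert_sans(cert, sans);                        // DNS / IP / EMAIL / URI entries
  time_t not_before = 0, not_after = 0;
  tls::get_cert_validity(cert, not_before, not_after);   // epoch seconds
  bool host_ok = tls::verify_hostname(cert, "example.com");
  tls::free_cert(cert);  // releases the OpenSSL ref; documented no-op for Mbed TLS
}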
+bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto crt = static_cast(cert); + if (!crt->raw.p || crt->raw.len == 0) return false; + der.assign(crt->raw.p, crt->raw.p + crt->raw.len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto msession = static_cast(session); + + // For server: return SNI received from client during handshake + if (!msession->sni_hostname.empty()) { + return msession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!msession->hostname.empty()) { return msession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + // Mbed TLS doesn't have an error queue, return the last error + return static_cast(-impl::mbedtls_last_error()); +} + +uint64_t get_error() { + // Mbed TLS doesn't have an error queue, return and clear the last error + uint64_t err = static_cast(-impl::mbedtls_last_error()); + impl::mbedtls_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + mbedtls_strerror(-static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto *ca_chain = new (std::nothrow) mbedtls_x509_crt; + if (!ca_chain) { return nullptr; } + + mbedtls_x509_crt_init(ca_chain); + + // mbedtls_x509_crt_parse expects null-terminated PEM + int ret = mbedtls_x509_crt_parse(ca_chain, + reinterpret_cast(pem), + len + 1); // +1 for null terminator + if (ret != 0) { + // Try without +1 in case PEM is already null-terminated + ret = mbedtls_x509_crt_parse( + ca_chain, reinterpret_cast(pem), len); + if (ret != 0) { + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + return nullptr; + } + } + + return static_cast(ca_chain); +} + +void free_ca_store(ca_store_t store) { + if (store) { + auto *ca_chain = static_cast(store); + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *mbed_ctx = static_cast(ctx); + auto *ca_chain = static_cast(store); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Copy the CA chain (deep copy) + // Parse from the raw data of the source cert + mbedtls_x509_crt *src = ca_chain; + while (src != nullptr) { + int ret = mbedtls_x509_crt_parse_der(&mbed_ctx->ca_chain, src->raw.p, + src->raw.len); + if (ret != 0) { return false; } + src = src->next; + } + + // Update the SSL config to use the new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while (cert != nullptr && cert->raw.len > 0) { + // Create a copy of the certificate for the caller + auto *copy = new mbedtls_x509_crt; + mbedtls_x509_crt_init(copy); + int ret = mbedtls_x509_crt_parse_der(copy, cert->raw.p, cert->raw.len); + if (ret == 0) { + certs.push_back(static_cast(copy)); + } else { + mbedtls_x509_crt_free(copy); + delete copy; + } + cert = cert->next; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while 
(cert != nullptr && cert->raw.len > 0) { + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &cert->subject); + if (ret > 0) { names.push_back(buf); } + cert = cert->next; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing certificate and key + mbedtls_x509_crt_free(&mbed_ctx->own_cert); + mbedtls_pk_free(&mbed_ctx->own_key); + mbedtls_x509_crt_init(&mbed_ctx->own_cert); + mbedtls_pk_init(&mbed_ctx->own_key); + + // Parse certificate PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->own_cert, reinterpret_cast(cert_pem), + strlen(cert_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key PEM +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0, mbedtls_ctr_drbg_random, + &mbed_ctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Configure SSL to use the new certificate and key + ret = mbedtls_ssl_conf_own_cert(&mbed_ctx->conf, &mbed_ctx->own_cert, + &mbed_ctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Parse CA PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->ca_chain, reinterpret_cast(ca_pem), + strlen(ca_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Update SSL config to use new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *mbed_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + mbed_ctx->has_verify_callback = + static_cast(impl::get_verify_callback()); + + if (mbed_ctx->has_verify_callback) { + // Set OPTIONAL mode to ensure callback is called even when verification + // is disabled (matching OpenSSL behavior where SSL_VERIFY_PEER is set) + mbedtls_ssl_conf_authmode(&mbed_ctx->conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_verify(&mbed_ctx->conf, impl::mbedtls_verify_callback, + nullptr); + } else { + mbedtls_ssl_conf_verify(&mbed_ctx->conf, nullptr, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_verify_result(&msession->ssl)); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + char buf[256]; + mbedtls_x509_crt_verify_info(buf, sizeof(buf), "", + static_cast(error_code)); + // Remove trailing newline if present + std::string result(buf); + while (!result.empty() && (result.back() == '\n' || result.back() == ' ')) { + result.pop_back(); + } + return 
result; +} + +} // namespace tls + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + } // namespace httplib diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 7c7790f41f..1fd8b1d1e8 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.30.2" -#define CPPHTTPLIB_VERSION_NUM "0x001E02" +#define CPPHTTPLIB_VERSION "0.31.0" +#define CPPHTTPLIB_VERSION_NUM "0x001F00" /* * Platform compatibility check @@ -147,7 +147,7 @@ #endif #ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH -#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (100 * 1024 * 1024) // 100MB #endif #ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH @@ -383,6 +383,45 @@ using socket_t = int; #endif // CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_MAC +#include +#endif +#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + +// Mbed TLS 3.x API compatibility +#if MBEDTLS_VERSION_MAJOR >= 3 +#define CPPHTTPLIB_MBEDTLS_V3 +#endif + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + +// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available +// This simplifies conditional compilation when adding new backends (e.g., +// wolfSSL) +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +#define CPPHTTPLIB_SSL_ENABLED +#endif + #ifdef CPPHTTPLIB_ZLIB_SUPPORT #include #endif @@ -799,6 +838,105 @@ public: using Range = std::pair; using Ranges = std::vector; +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - public type definitions and API +namespace tls { + +// Opaque handles (defined as void* for abstraction) +using ctx_t = void *; +using session_t = void *; +using const_session_t = const void *; // For read-only session access +using cert_t = void *; +using ca_store_t = void *; + +// TLS versions +enum class Version { + TLS1_2 = 0x0303, + TLS1_3 = 0x0304, +}; + +// Subject Alternative Names (SAN) entry types +enum class SanType { DNS, IP, EMAIL, URI, OTHER }; + +// SAN entry structure +struct SanEntry { + SanType type; + std::string value; +}; + +// Verification context for certificate verification callback +struct VerifyContext { + session_t session; // TLS session handle + cert_t cert; // Current certificate being verified + int depth; // Certificate chain depth (0 = leaf) + bool preverify_ok; // OpenSSL/Mbed TLS pre-verification result + long error_code; // Backend-specific error code (0 = no error) + const char *error_string; // Human-readable error description + + // Certificate introspection methods + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; +}; + +using VerifyCallback = std::function; + +// TlsError codes for TLS operations (backend-independent) +enum class ErrorCode : int { + Success = 0, + WantRead, // Non-blocking: need to wait for read + WantWrite, // Non-blocking: need to wait for write + PeerClosed, // Peer closed the connection + Fatal, // Unrecoverable error + SyscallError, // System call error (check sys_errno) + 
CertVerifyFailed, // Certificate verification failed + HostnameMismatch, // Hostname verification failed +}; + +// TLS error information +struct TlsError { + ErrorCode code = ErrorCode::Fatal; + uint64_t backend_code = 0; // OpenSSL: ERR_get_error(), mbedTLS: return value + int sys_errno = 0; // errno when SyscallError + + // Convert verification error code to human-readable string + static std::string verify_error_to_string(long error_code); +}; + +// RAII wrapper for peer certificate +class PeerCert { +public: + PeerCert(); + PeerCert(PeerCert &&other) noexcept; + PeerCert &operator=(PeerCert &&other) noexcept; + ~PeerCert(); + + PeerCert(const PeerCert &) = delete; + PeerCert &operator=(const PeerCert &) = delete; + + explicit operator bool() const; + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; + +private: + explicit PeerCert(cert_t cert); + cert_t cert_ = nullptr; + friend PeerCert get_peer_cert_from_session(const_session_t session); +}; + +// Callback for TLS context setup (used by SSLServer constructor) +using ContextSetupCallback = std::function; + +} // namespace tls +#endif + struct Request { std::string method; std::string path; @@ -828,9 +966,6 @@ struct Request { ContentReceiverWithProgress content_receiver; DownloadProgress download_progress; UploadProgress upload_progress; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl = nullptr; -#endif bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -858,6 +993,12 @@ struct Request { size_t authorization_count_ = 0; std::chrono::time_point start_time_ = (std::chrono::steady_clock::time_point::min)(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::const_session_t ssl = nullptr; + tls::PeerCert peer_cert() const; + std::string sni() const; +#endif }; struct Response { @@ -1005,74 +1146,18 @@ public: class ThreadPool final : public TaskQueue { public: - explicit ThreadPool(size_t n, size_t mqr = 0) - : shutdown_(false), max_queued_requests_(mqr) { - threads_.reserve(n); - while (n) { - threads_.emplace_back(worker(*this)); - n--; - } - } - + explicit ThreadPool(size_t n, size_t mqr = 0); ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; - bool enqueue(std::function fn) override { - { - std::unique_lock lock(mutex_); - if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { - return false; - } - jobs_.push_back(std::move(fn)); - } - - cond_.notify_one(); - return true; - } - - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; - } - - cond_.notify_all(); - - // Join... 
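With the tls abstraction above, a server handler no longer reaches for the raw const SSL * that the old req.ssl member exposed; the peer certificate comes back as a backend-neutral tls::PeerCert via Request::peer_cert(), and the negotiated server name via Request::sni(). A minimal sketch of a handler using that surface, assuming an SSL-enabled build with client-certificate verification; the "/whoami" route and the response text are illustrative only:

// Sketch only: reads the verified client certificate through the new
// backend-neutral accessors (route name and messages are placeholders).
#include "httplib.h"

void register_whoami(httplib::SSLServer &svr) {
  svr.Get("/whoami", [](const httplib::Request &req, httplib::Response &res) {
    auto cert = req.peer_cert();            // tls::PeerCert, empty when no client cert
    if (!cert) {
      res.status = 401;
      res.set_content("no client certificate presented", "text/plain");
      return;
    }
    std::string body = "CN=" + cert.subject_cn() +
                       " issuer=" + cert.issuer_name() +
                       " sni=" + req.sni();
    for (const auto &san : cert.sans()) {   // tls::SanEntry { type, value }
      if (san.type == httplib::tls::SanType::DNS) { body += " dns=" + san.value; }
    }
    res.set_content(body, "text/plain");
  });
}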
- for (auto &t : threads_) { - t.join(); - } - } + bool enqueue(std::function fn) override; + void shutdown() override; private: struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + explicit worker(ThreadPool &pool); - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); - - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } - - assert(true == static_cast(fn)); - fn(); - } - -#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ - !defined(LIBRESSL_VERSION_NUMBER) - OPENSSL_thread_stop(); -#endif - } + void operator()(); ThreadPool &pool_; }; @@ -1184,6 +1269,9 @@ int close_socket(socket_t sock); ssize_t write_headers(Stream &strm, const Headers &headers); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + } // namespace detail class Server { @@ -1429,17 +1517,6 @@ public: Headers &&request_headers = Headers{}) : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)) {} -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error, unsigned long ssl_openssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error), - ssl_openssl_error_(ssl_openssl_error) {} -#endif // Response operator bool() const { return res_ != nullptr; } bool operator==(std::nullptr_t) const { return res_ == nullptr; } @@ -1454,13 +1531,6 @@ public: // Error Error error() const { return err_; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // SSL Error - int ssl_error() const { return ssl_error_; } - // OpenSSL Error - unsigned long ssl_openssl_error() const { return ssl_openssl_error_; } -#endif - // Request Headers bool has_request_header(const std::string &key) const; std::string get_request_header_value(const std::string &key, @@ -1474,64 +1544,76 @@ private: std::unique_ptr res_; Error err_ = Error::Unknown; Headers request_headers_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error, unsigned long ssl_backend_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error), + ssl_backend_error_(ssl_backend_error) {} + + int ssl_error() const { return ssl_error_; } + unsigned long ssl_backend_error() const { return ssl_backend_error_; } + +private: int ssl_error_ = 0; - unsigned long ssl_openssl_error_ = 0; + unsigned long ssl_backend_error_ = 0; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use ssl_backend_error() instead")]] + unsigned long ssl_openssl_error() const { + return ssl_backend_error_; + } #endif }; struct ClientConnection { socket_t sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; -#endif bool is_open() const { return sock != INVALID_SOCKET; } ClientConnection() = default; - ~ClientConnection() { 
-#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (ssl) { - SSL_free(ssl); - ssl = nullptr; - } -#endif - if (sock != INVALID_SOCKET) { - detail::close_socket(sock); - sock = INVALID_SOCKET; - } - } + ~ClientConnection(); ClientConnection(const ClientConnection &) = delete; ClientConnection &operator=(const ClientConnection &) = delete; ClientConnection(ClientConnection &&other) noexcept : sock(other.sock) -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED , - ssl(other.ssl) + session(other.session) #endif { other.sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; +#ifdef CPPHTTPLIB_SSL_ENABLED + other.session = nullptr; #endif } ClientConnection &operator=(ClientConnection &&other) noexcept { if (this != &other) { sock = other.sock; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ssl = other.ssl; -#endif other.sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; +#ifdef CPPHTTPLIB_SSL_ENABLED + session = other.session; + other.session = nullptr; #endif } return *this; } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t session = nullptr; +#endif }; namespace detail { @@ -1540,7 +1622,9 @@ struct ChunkedDecoder; struct BodyReader { Stream *stream = nullptr; + bool has_content_length = false; size_t content_length = 0; + size_t payload_max_length = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; size_t bytes_read = 0; bool chunked = false; bool eof = false; @@ -1610,6 +1694,7 @@ public: std::unique_ptr decompressor_; std::string decompress_buffer_; size_t decompress_offset_ = 0; + size_t decompressed_bytes_read_ = 0; }; // clang-format off @@ -1756,10 +1841,6 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); @@ -1770,30 +1851,14 @@ public: void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - void set_ca_cert_store(X509_STORE *ca_cert_store); - X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); @@ -1801,11 +1866,15 @@ public: protected: struct Socket { socket_t sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; -#endif + + // For Mbed TLS compatibility: start_time for request timeout tracking + std::chrono::time_point start_time_; bool is_open() const { return sock != INVALID_SOCKET; } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t ssl = nullptr; +#endif }; virtual bool create_and_connect_socket(Socket &socket, Error &error); @@ -1872,10 +1941,6 @@ protected: 
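On the client side, the Result changes above make TLS failure details backend-neutral: ssl_error() carries the abstraction-layer code and ssl_backend_error() the raw OpenSSL or Mbed TLS value, with the old ssl_openssl_error() kept only as a deprecated alias. A short sketch of surfacing those fields after a failed request; the host and path are placeholders and the build is assumed to define one of the SSL backends:

#include "httplib.h"
#include <cstdio>
#include <string>

bool fetch_secure(const std::string &host, const std::string &path) {
  httplib::Client cli("https://" + host);   // scheme form dispatches to the TLS client
  auto res = cli.Get(path);
  if (!res) {
    std::fprintf(stderr, "request failed: %s\n",
                 httplib::to_string(res.error()).c_str());
#ifdef CPPHTTPLIB_SSL_ENABLED
    std::fprintf(stderr, "tls: ssl_error=%d backend_error=%lu\n",
                 res.ssl_error(), res.ssl_backend_error());
#endif
    return false;
  }
  return res->status == 200;
}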
std::string basic_auth_username_; std::string basic_auth_password_; std::string bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; -#endif bool keep_alive_ = false; bool follow_location_ = false; @@ -1890,6 +1955,9 @@ protected: bool compress_ = false; bool decompress_ = true; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool has_payload_max_length_ = false; + std::string interface_; std::string proxy_host_; @@ -1898,33 +1966,11 @@ protected: std::string proxy_basic_auth_username_; std::string proxy_basic_auth_password_; std::string proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - - X509_STORE *ca_cert_store_ = nullptr; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool server_certificate_verification_ = true; - bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; -#endif mutable std::mutex logger_mutex_; Logger logger_; ErrorLogger error_logger_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - int last_ssl_error_ = 0; - unsigned long last_openssl_error_ = 0; -#endif - private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); @@ -1969,6 +2015,44 @@ private: virtual bool is_ssl() const; void transfer_socket_ownership_to_handle(StreamHandle &handle); + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + +protected: + std::string digest_auth_username_; + std::string digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::string ca_cert_pem_; // Store CA cert PEM for redirect transfer + int last_ssl_error_ = 0; + unsigned long last_backend_error_ = 0; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use load_ca_cert_store() instead")]] + void set_ca_cert_store(X509_STORE *ca_cert_store); + + [[deprecated("Use tls::create_ca_store() instead")]] + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; + + [[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead")]] + virtual void set_server_certificate_verifier( + std::function verifier); +#endif }; class Client { @@ -2138,10 +2222,6 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); @@ -2153,49 +2233,65 @@ public: void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const 
std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif - void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); - // SSL -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - - void set_ca_cert_store(X509_STORE *ca_cert_store); - void load_ca_cert_store(const char *ca_cert, std::size_t size); - - long get_openssl_verify_result() const; - - SSL_CTX *ssl_context() const; -#endif - private: std::unique_ptr cli_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(tls::ca_store_t ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + void set_server_certificate_verifier(tls::VerifyCallback verifier); + + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const; + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + void enable_windows_certificate_verification(bool enabled); +#endif + +private: bool is_ssl_ = false; #endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use set_session_verifier(session_t) instead")]] + void set_server_certificate_verifier( + std::function verifier); + + [[deprecated("Use Result::ssl_backend_error() instead")]] + long get_verify_result() const; +#endif }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED class SSLServer : public Server { public: SSLServer(const char *cert_path, const char *private_key_path, @@ -2203,32 +2299,60 @@ public: const char *client_ca_cert_dir_path = nullptr, const char *private_key_password = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *client_ca_pem; + size_t client_ca_pem_len; + const char *private_key_password; + }; + explicit SSLServer(const PemMemory &pem); - SSLServer( - const std::function &setup_ssl_ctx_callback); + // The callback receives the ctx_t handle which can be cast to the + // appropriate backend type (SSL_CTX* for OpenSSL, + // tls::impl::MbedTlsContext* for Mbed TLS) + explicit SSLServer(const tls::ContextSetupCallback &setup_callback); ~SSLServer() override; bool is_valid() const override; - SSL_CTX *ssl_context() const; + bool update_certs_pem(const char *cert_pem, const char *key_pem, + const char *client_ca_pem = nullptr, + const char *password = nullptr); - void update_certs(X509 *cert, EVP_PKEY *private_key, - 
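Two of the reworked SSLServer entry points above can be exercised without any OpenSSL types: the PemMemory constructor takes in-memory PEM buffers, and update_certs_pem() swaps certificates on a live server. A minimal sketch under the assumption that the PEM strings are loaded elsewhere; error handling is reduced to the is_valid() check:

#include "httplib.h"
#include <memory>
#include <string>

std::unique_ptr<httplib::SSLServer>
make_tls_server(const std::string &cert_pem, const std::string &key_pem) {
  httplib::SSLServer::PemMemory pem{};   // value-init leaves optional fields nullptr/0
  pem.cert_pem     = cert_pem.c_str();
  pem.cert_pem_len = cert_pem.size();
  pem.key_pem      = key_pem.c_str();
  pem.key_pem_len  = key_pem.size();

  auto svr = std::make_unique<httplib::SSLServer>(pem);
  if (!svr->is_valid()) { return nullptr; }
  return svr;
}

// Certificate rotation on the same server instance, e.g. after renewal:
//   svr->update_certs_pem(new_cert_pem.c_str(), new_key_pem.c_str());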
X509_STORE *client_ca_cert_store = nullptr); + tls::ctx_t tls_context() const { return ctx_; } int ssl_last_error() const { return last_ssl_error_; } private: bool process_and_close_socket(socket_t sock) override; - STACK_OF(X509_NAME) * extract_ca_names_from_x509_store(X509_STORE *store); - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; int last_ssl_error_ = 0; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use SSLServer(PemMemory) or " + "SSLServer(ContextSetupCallback) instead")]] + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + [[deprecated("Use SSLServer(ContextSetupCallback) instead")]] + SSLServer( + const std::function &setup_ssl_ctx_callback); + + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use update_certs_pem() instead")]] + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); +#endif }; class SSLClient final : public ClientImpl { @@ -2242,20 +2366,34 @@ public: const std::string &client_key_path, const std::string &private_key_password = std::string()); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key, - const std::string &private_key_password = std::string()); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *private_key_password; + }; + explicit SSLClient(const std::string &host, int port, const PemMemory &pem); ~SSLClient() override; bool is_valid() const override; - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(tls::ca_store_t ca_cert_store); void load_ca_cert_store(const char *ca_cert, std::size_t size); - long get_openssl_verify_result() const; + void set_server_certificate_verifier(tls::VerifyCallback verifier); - SSL_CTX *ssl_context() const; + // Post-handshake session verifier (backend-independent) + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const { return ctx_; } + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + void enable_windows_certificate_verification(bool enabled); +#endif private: bool create_and_connect_socket(Socket &socket, Error &error) override; @@ -2277,26 +2415,45 @@ private: bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; std::once_flag initialize_cert_; - std::vector host_components_; - long verify_result_ = 0; - friend class ClientImpl; -}; + std::function session_verifier_; + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + bool enable_windows_cert_verification_ = true; #endif -/* - * Implementation of template methods. 
- */ + friend class ClientImpl; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use SSLClient(host, port, PemMemory) instead")]] + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + [[deprecated("Use Result::ssl_backend_error() instead")]] + long get_verify_result() const; + + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use set_session_verifier(session_t) instead")]] + void set_server_certificate_verifier( + std::function verifier) override; + +private: + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; +#endif +}; +#endif // CPPHTTPLIB_SSL_ENABLED namespace detail { @@ -2345,66 +2502,6 @@ inline size_t get_header_value_u64(const Headers &headers, } // namespace detail -inline size_t Request::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -inline size_t Response::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -namespace detail { - -inline bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen) { - return setsockopt(sock, level, optname, -#ifdef _WIN32 - reinterpret_cast(optval), -#else - optval, -#endif - optlen) == 0; -} - -inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { - return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); -} - -inline bool set_socket_opt_time(socket_t sock, int level, int optname, - time_t sec, time_t usec) { -#ifdef _WIN32 - auto timeout = static_cast(sec * 1000 + usec / 1000); -#else - timeval timeout; - timeout.tv_sec = static_cast(sec); - timeout.tv_usec = static_cast(usec); -#endif - return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); -} - -} // namespace detail - -inline void default_socket_options(socket_t sock) { - detail::set_socket_opt(sock, SOL_SOCKET, -#ifdef SO_REUSEPORT - SO_REUSEPORT, -#else - SO_REUSEADDR, -#endif - 1); -} - -inline std::string get_bearer_token_auth(const Request &req) { - if (req.has_header("Authorization")) { - constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); - return req.get_header_value("Authorization") - .substr(bearer_header_prefix_len); - } - return ""; -} - template inline Server & Server::set_read_timeout(const std::chrono::duration &duration) { @@ -2429,12 +2526,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline size_t Result::get_request_header_value_u64(const std::string &key, - size_t def, - size_t id) const { - return detail::get_header_value_u64(request_headers_, key, def, id); -} - template inline void ClientImpl::set_connection_timeout( const std::chrono::duration &duration) { @@ -2842,105 +2933,73 @@ bool is_field_content(const std::string &s); bool is_field_value(const std::string &s); } // namespace fields - } // namespace detail +/* + * TLS Abstraction Layer Declarations + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - backend-specific type declarations +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { +namespace impl { + +// Mbed TLS context wrapper (holds config, entropy, DRBG, CA chain, own +// cert/key). 
This struct is accessible via tls::impl for use in SSL context +// setup callbacks (cast ctx_t to tls::impl::MbedTlsContext*). +struct MbedTlsContext { + mbedtls_ssl_config conf; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_x509_crt ca_chain; + mbedtls_x509_crt own_cert; + mbedtls_pk_context own_key; + bool is_server = false; + bool verify_client = false; + bool has_verify_callback = false; + + MbedTlsContext(); + ~MbedTlsContext(); + + MbedTlsContext(const MbedTlsContext &) = delete; + MbedTlsContext &operator=(const MbedTlsContext &) = delete; +}; + +} // namespace impl +} // namespace tls +#endif + +#endif // CPPHTTPLIB_SSL_ENABLED + namespace stream { class Result { public: - Result() : chunk_size_(8192) {} - - explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192) - : handle_(std::move(handle)), chunk_size_(chunk_size) {} - - Result(Result &&other) noexcept - : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), - current_size_(other.current_size_), chunk_size_(other.chunk_size_), - finished_(other.finished_) { - other.current_size_ = 0; - other.finished_ = true; - } - - Result &operator=(Result &&other) noexcept { - if (this != &other) { - handle_ = std::move(other.handle_); - buffer_ = std::move(other.buffer_); - current_size_ = other.current_size_; - chunk_size_ = other.chunk_size_; - finished_ = other.finished_; - other.current_size_ = 0; - other.finished_ = true; - } - return *this; - } - + Result(); + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192); + Result(Result &&other) noexcept; + Result &operator=(Result &&other) noexcept; Result(const Result &) = delete; Result &operator=(const Result &) = delete; - // Check if the result is valid (connection succeeded and response received) - bool is_valid() const { return handle_.is_valid(); } - explicit operator bool() const { return is_valid(); } - - // Response status code - int status() const { - return handle_.response ? handle_.response->status : -1; - } - - // Response headers - const Headers &headers() const { - static const Headers empty_headers; - return handle_.response ? handle_.response->headers : empty_headers; - } - + // Response info + bool is_valid() const; + explicit operator bool() const; + int status() const; + const Headers &headers() const; std::string get_header_value(const std::string &key, - const char *def = "") const { - return handle_.response ? handle_.response->get_header_value(key, def) - : def; - } + const char *def = "") const; + bool has_header(const std::string &key) const; + Error error() const; + Error read_error() const; + bool has_read_error() const; - bool has_header(const std::string &key) const { - return handle_.response ? 
handle_.response->has_header(key) : false; - } - - // Error information - Error error() const { return handle_.error; } - Error read_error() const { return handle_.get_read_error(); } - bool has_read_error() const { return handle_.has_read_error(); } - - // Streaming iteration API - // Call next() to read the next chunk, then access data via data()/size() - // Returns true if data was read, false when stream is exhausted - bool next() { - if (!handle_.is_valid() || finished_) { return false; } - - if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } - - ssize_t n = handle_.read(&buffer_[0], chunk_size_); - if (n > 0) { - current_size_ = static_cast(n); - return true; - } - - current_size_ = 0; - finished_ = true; - return false; - } - - // Pointer to current chunk data (valid after next() returns true) - const char *data() const { return buffer_.data(); } - - // Size of current chunk (valid after next() returns true) - size_t size() const { return current_size_; } - - // Convenience method: read all remaining data into a string - std::string read_all() { - std::string result; - while (next()) { - result.append(data(), size()); - } - return result; - } + // Stream reading + bool next(); + const char *data() const; + size_t size() const; + std::string read_all(); private: ClientImpl::StreamHandle handle_; @@ -3205,13 +3264,8 @@ struct SSEMessage { std::string data; // Event payload std::string id; // Event ID for Last-Event-ID header - SSEMessage() : event("message") {} - - void clear() { - event = "message"; - data.clear(); - id.clear(); - } + SSEMessage(); + void clear(); }; class SSEClient { @@ -3220,255 +3274,40 @@ public: using ErrorHandler = std::function; using OpenHandler = std::function; - SSEClient(Client &client, const std::string &path) - : client_(client), path_(path) {} - - SSEClient(Client &client, const std::string &path, const Headers &headers) - : client_(client), path_(path), headers_(headers) {} - - ~SSEClient() { stop(); } + SSEClient(Client &client, const std::string &path); + SSEClient(Client &client, const std::string &path, const Headers &headers); + ~SSEClient(); SSEClient(const SSEClient &) = delete; SSEClient &operator=(const SSEClient &) = delete; // Event handlers - SSEClient &on_message(MessageHandler handler) { - on_message_ = std::move(handler); - return *this; - } - - SSEClient &on_event(const std::string &type, MessageHandler handler) { - event_handlers_[type] = std::move(handler); - return *this; - } - - SSEClient &on_open(OpenHandler handler) { - on_open_ = std::move(handler); - return *this; - } - - SSEClient &on_error(ErrorHandler handler) { - on_error_ = std::move(handler); - return *this; - } - - SSEClient &set_reconnect_interval(int ms) { - reconnect_interval_ms_ = ms; - return *this; - } - - SSEClient &set_max_reconnect_attempts(int n) { - max_reconnect_attempts_ = n; - return *this; - } + SSEClient &on_message(MessageHandler handler); + SSEClient &on_event(const std::string &type, MessageHandler handler); + SSEClient &on_open(OpenHandler handler); + SSEClient &on_error(ErrorHandler handler); + SSEClient &set_reconnect_interval(int ms); + SSEClient &set_max_reconnect_attempts(int n); // State accessors - bool is_connected() const { return connected_.load(); } - const std::string &last_event_id() const { return last_event_id_; } + bool is_connected() const; + const std::string &last_event_id() const; // Blocking start - runs event loop with auto-reconnect - void start() { - running_.store(true); - run_event_loop(); - } + void start(); 
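The SSEClient declarations above read naturally once the moved-out bodies are taken into account: handlers are chained, start_async() runs the reconnecting event loop on a background thread, and stop() joins it. A usage sketch; the endpoint, the "ticker" event name, and the assumption that SSEClient sits directly in namespace httplib (as the surrounding hunks suggest) are placeholders or assumptions:

#include "httplib.h"
#include <cstdio>

void listen_for_events() {
  httplib::Client cli("https://example.com");             // placeholder endpoint
  httplib::SSEClient sse(cli, "/events");                  // path is illustrative

  sse.on_open([] { std::puts("connected"); })
     .on_message([](const httplib::SSEMessage &msg) {      // default "message" events
        std::printf("[%s] %s\n", msg.event.c_str(), msg.data.c_str());
      })
     .on_event("ticker", [](const httplib::SSEMessage &msg) {  // named event type
        std::printf("ticker: %s\n", msg.data.c_str());
      })
     .on_error([](httplib::Error err) {
        std::fprintf(stderr, "sse error: %s\n", httplib::to_string(err).c_str());
      })
     .set_reconnect_interval(3000)                          // ms between reconnects
     .set_max_reconnect_attempts(5);                        // 0 would mean unlimited

  sse.start_async();   // background thread; start() would block instead
  // ... application work ...
  sse.stop();          // cancels pending requests and joins the thread
}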
// Non-blocking start - runs in background thread - void start_async() { - running_.store(true); - async_thread_ = std::thread([this]() { run_event_loop(); }); - } + void start_async(); // Stop the client (thread-safe) - void stop() { - running_.store(false); - client_.stop(); // Cancel any pending operations - if (async_thread_.joinable()) { async_thread_.join(); } - } + void stop(); private: - // Parse a single SSE field line - // Returns true if this line ends an event (blank line) - bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms) { - // Blank line signals end of event - if (line.empty() || line == "\r") { return true; } - - // Lines starting with ':' are comments (ignored) - if (!line.empty() && line[0] == ':') { return false; } - - // Find the colon separator - auto colon_pos = line.find(':'); - if (colon_pos == std::string::npos) { - // Line with no colon is treated as field name with empty value - return false; - } - - auto field = line.substr(0, colon_pos); - std::string value; - - // Value starts after colon, skip optional single space - if (colon_pos + 1 < line.size()) { - auto value_start = colon_pos + 1; - if (line[value_start] == ' ') { value_start++; } - value = line.substr(value_start); - // Remove trailing \r if present - if (!value.empty() && value.back() == '\r') { value.pop_back(); } - } - - // Handle known fields - if (field == "event") { - msg.event = value; - } else if (field == "data") { - // Multiple data lines are concatenated with newlines - if (!msg.data.empty()) { msg.data += "\n"; } - msg.data += value; - } else if (field == "id") { - // Empty id is valid (clears the last event ID) - msg.id = value; - } else if (field == "retry") { - // Parse retry interval in milliseconds - { - int v = 0; - auto res = - detail::from_chars(value.data(), value.data() + value.size(), v); - if (res.ec == std::errc{}) { retry_ms = v; } - } - } - // Unknown fields are ignored per SSE spec - - return false; - } - - // Main event loop with auto-reconnect - void run_event_loop() { - auto reconnect_count = 0; - - while (running_.load()) { - // Build headers, including Last-Event-ID if we have one - auto request_headers = headers_; - if (!last_event_id_.empty()) { - request_headers.emplace("Last-Event-ID", last_event_id_); - } - - // Open streaming connection - auto result = stream::Get(client_, path_, request_headers); - - // Connection error handling - if (!result) { - connected_.store(false); - if (on_error_) { on_error_(result.error()); } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - continue; - } - - if (result.status() != 200) { - connected_.store(false); - // For certain errors, don't reconnect - if (result.status() == 204 || // No Content - server wants us to stop - result.status() == 404 || // Not Found - result.status() == 401 || // Unauthorized - result.status() == 403) { // Forbidden - if (on_error_) { on_error_(Error::Connection); } - break; - } - - if (on_error_) { on_error_(Error::Connection); } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - continue; - } - - // Connection successful - connected_.store(true); - reconnect_count = 0; - if (on_open_) { on_open_(); } - - // Event receiving loop - std::string buffer; - SSEMessage current_msg; - - while (running_.load() && result.next()) { - buffer.append(result.data(), result.size()); - - // Process complete lines in the buffer - size_t line_start = 0; - size_t newline_pos; - - while 
((newline_pos = buffer.find('\n', line_start)) != - std::string::npos) { - auto line = buffer.substr(line_start, newline_pos - line_start); - line_start = newline_pos + 1; - - // Parse the line and check if event is complete - auto event_complete = - parse_sse_line(line, current_msg, reconnect_interval_ms_); - - if (event_complete && !current_msg.data.empty()) { - // Update last_event_id for reconnection - if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } - - // Dispatch event to appropriate handler - dispatch_event(current_msg); - - current_msg.clear(); - } - } - - // Keep unprocessed data in buffer - buffer.erase(0, line_start); - } - - // Connection ended - connected_.store(false); - - if (!running_.load()) { break; } - - // Check for read errors - if (result.has_read_error()) { - if (on_error_) { on_error_(result.read_error()); } - } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - } - - connected_.store(false); - } - - // Dispatch event to appropriate handler - void dispatch_event(const SSEMessage &msg) { - // Check for specific event type handler first - auto it = event_handlers_.find(msg.event); - if (it != event_handlers_.end()) { - it->second(msg); - return; - } - - // Fall back to generic message handler - if (on_message_) { on_message_(msg); } - } - - // Check if we should attempt to reconnect - bool should_reconnect(int count) const { - if (!running_.load()) { return false; } - if (max_reconnect_attempts_ == 0) { return true; } // unlimited - return count < max_reconnect_attempts_; - } - - // Wait for reconnect interval - void wait_for_reconnect() { - // Use small increments to check running_ flag frequently - auto waited = 0; - while (running_.load() && waited < reconnect_interval_ms_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - waited += 100; - } - } + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms); + void run_event_loop(); + void dispatch_event(const SSEMessage &msg); + bool should_reconnect(int count) const; + void wait_for_reconnect(); // Client and path Client &client_;
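The event loop removed above is also the clearest illustration of the stream::Result surface declared earlier: next() pulls a chunk, data()/size() expose it, and has_read_error()/read_error() report mid-stream failures. A standalone sketch of the same pattern, assuming stream::Get keeps the (client, path, headers) shape used in run_event_loop:

#include "httplib.h"
#include <cstdio>
#include <string>

void stream_download(httplib::Client &cli, const std::string &path) {
  auto res = httplib::stream::Get(cli, path, httplib::Headers{});  // signature assumed
  if (!res || res.status() != 200) {
    std::fprintf(stderr, "stream failed: %s (status %d)\n",
                 httplib::to_string(res.error()).c_str(), res.status());
    return;
  }

  size_t total = 0;
  while (res.next()) {      // pulls the next chunk into the internal buffer
    total += res.size();    // res.data()/res.size() are valid until the next call
    // consume res.data() here ...
  }

  if (res.has_read_error()) {
    std::fprintf(stderr, "read error: %s\n",
                 httplib::to_string(res.read_error()).c_str());
  }
  std::printf("received %zu bytes\n", total);
}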